Commit d10f25a0 authored by letaoqin's avatar letaoqin
Browse files

bwd: rename biases to bias

parent 127982f1
...@@ -515,8 +515,8 @@ int run(int argc, char* argv[]) ...@@ -515,8 +515,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases; static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_biases; nullptr, // p_acc1_bias;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
...@@ -528,10 +528,10 @@ int run(int argc, char* argv[]) ...@@ -528,10 +528,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_biases_gs_ms_os_strides, {}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{alpha},
...@@ -560,8 +560,8 @@ int run(int argc, char* argv[]) ...@@ -560,8 +560,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases; static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_biases; nullptr, // p_acc1_bias;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
k_gs_ns_ks_lengths, k_gs_ns_ks_lengths,
...@@ -573,10 +573,10 @@ int run(int argc, char* argv[]) ...@@ -573,10 +573,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_biases_gs_ms_os_strides, {}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{alpha},
......
...@@ -437,8 +437,8 @@ int run(int argc, char* argv[]) ...@@ -437,8 +437,8 @@ int run(int argc, char* argv[])
lse_gs_ms_strides, lse_gs_ms_strides,
d0_gs_ms_ns_lengths, d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides, d0_gs_ms_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // acc1_bias_gs_ms_os_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // acc1_bias_gs_ms_os_strides,
}); });
int BatchCount = G0 * G1; int BatchCount = G0 * G1;
......
...@@ -299,11 +299,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -299,11 +299,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> acc0_biases_gs_ms_ns_lengths; std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<index_t> acc0_biases_gs_ms_ns_strides; std::vector<index_t> acc0_bias_gs_ms_ns_strides;
std::vector<index_t> acc1_biases_gs_ms_os_lengths; std::vector<index_t> acc1_bias_gs_ms_os_lengths;
std::vector<index_t> acc1_biases_gs_ms_os_strides; std::vector<index_t> acc1_bias_gs_ms_os_strides;
}; };
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
} }
// D in Gemm0 C position // D in Gemm0 C position
static auto static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths, return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
static auto static auto
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths, return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...@@ -756,8 +755,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -756,8 +755,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -788,9 +787,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -788,9 +787,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_biases.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias.size()) ||
ck::type_convert<ck::index_t>(p_acc0_biases.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias.size() == 0)) &&
0 == p_acc1_biases.size())) 0 == p_acc1_bias.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -804,8 +803,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -804,8 +803,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid = const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_biases.size()) == group_count_) (ck::type_convert<ck::index_t>(p_acc0_bias.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_biases[i]) ? static_cast<const D0DataType*>(p_acc0_bias[i])
: nullptr; : nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
...@@ -827,8 +826,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -827,8 +826,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> tmp_d0_gs_ms_ns_strides; std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths; tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_bias_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides; tmp_d0_gs_ms_ns_strides = problem_desc.acc0_bias_gs_ms_ns_strides;
} }
else else
{ {
...@@ -971,12 +970,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -971,12 +970,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_n_length_stride}); d0_n_length_stride});
} }
// TODO: implement bias addition // TODO: implement bias addition
// ignore = p_acc0_biases; // ignore = p_acc0_bias;
// ignore = p_acc1_biases; // ignore = p_acc1_bias;
// ignore = acc0_biases_gs_ms_ns_lengths; // ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_biases_gs_ms_ns_strides; // ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_biases_gs_ms_gemm1ns_lengths; // ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_biases_gs_ms_gemm1ns_strides; // ignore = acc1_bias_gs_ms_gemm1ns_strides;
} }
// element-wise op // element-wise op
...@@ -1197,8 +1196,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1197,8 +1196,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1218,8 +1217,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1218,8 +1217,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_biases, p_acc0_bias,
p_acc1_biases, p_acc1_bias,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1245,8 +1244,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1245,8 +1244,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1266,8 +1265,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1266,8 +1265,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_biases, // cast in struct Argument p_acc0_bias, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_bias, // cast in struct Argument
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
...@@ -306,11 +306,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -306,11 +306,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> acc0_biases_gs_ms_ns_lengths; std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<index_t> acc0_biases_gs_ms_ns_strides; std::vector<index_t> acc0_bias_gs_ms_ns_strides;
std::vector<index_t> acc1_biases_gs_ms_os_lengths; std::vector<index_t> acc1_bias_gs_ms_os_lengths;
std::vector<index_t> acc1_biases_gs_ms_os_strides; std::vector<index_t> acc1_bias_gs_ms_os_strides;
}; };
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
} }
// D in Gemm0 C position // D in Gemm0 C position
static auto static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths, return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
static auto static auto
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths, MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths, return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...@@ -764,8 +763,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -764,8 +763,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -796,9 +795,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -796,9 +795,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_biases.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias.size()) ||
ck::type_convert<ck::index_t>(p_acc0_biases.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias.size() == 0)) &&
0 == p_acc1_biases.size())) 0 == p_acc1_bias.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -812,8 +811,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -812,8 +811,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid = const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_biases.size()) == group_count_) (ck::type_convert<ck::index_t>(p_acc0_bias.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_biases[i]) ? static_cast<const D0DataType*>(p_acc0_bias[i])
: nullptr; : nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
...@@ -835,8 +834,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -835,8 +834,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<index_t> tmp_d0_gs_ms_ns_strides; std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths; tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_bias_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides; tmp_d0_gs_ms_ns_strides = problem_desc.acc0_bias_gs_ms_ns_strides;
} }
else else
{ {
...@@ -979,12 +978,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -979,12 +978,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_n_length_stride}); d0_n_length_stride});
} }
// TODO: implement bias addition // TODO: implement bias addition
// ignore = p_acc0_biases; // ignore = p_acc0_bias;
// ignore = p_acc1_biases; // ignore = p_acc1_bias;
// ignore = acc0_biases_gs_ms_ns_lengths; // ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_biases_gs_ms_ns_strides; // ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_biases_gs_ms_gemm1ns_lengths; // ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_biases_gs_ms_gemm1ns_strides; // ignore = acc1_bias_gs_ms_gemm1ns_strides;
} }
// element-wise op // element-wise op
...@@ -1209,8 +1208,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1209,8 +1208,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1230,8 +1229,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1230,8 +1229,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_biases, p_acc0_bias,
p_acc1_biases, p_acc1_bias,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1257,8 +1256,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1257,8 +1256,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads, std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads, std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases, const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_biases, const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1278,8 +1277,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1278,8 +1277,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads, p_Qgrads,
p_Kgrads, p_Kgrads,
p_Vgrads, p_Vgrads,
p_acc0_biases, // cast in struct Argument p_acc0_bias, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_bias, // cast in struct Argument
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment