Commit d10f25a0 authored by letaoqin's avatar letaoqin
Browse files

bwd biaes to bias

parent 127982f1
......@@ -515,8 +515,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases;
nullptr, // p_acc1_biases;
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -528,10 +528,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths,
{}, // acc1_biases_gs_ms_os_strides,
d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
......@@ -560,8 +560,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_biases;
nullptr, // p_acc1_biases;
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -573,10 +573,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths,
{}, // acc1_biases_gs_ms_os_strides,
d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
......
......@@ -437,8 +437,8 @@ int run(int argc, char* argv[])
lse_gs_ms_strides,
d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
});
int BatchCount = G0 * G1;
......
......@@ -299,11 +299,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> acc0_biases_gs_ms_ns_lengths;
std::vector<index_t> acc0_biases_gs_ms_ns_strides;
std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<index_t> acc0_bias_gs_ms_ns_strides;
std::vector<index_t> acc1_biases_gs_ms_os_lengths;
std::vector<index_t> acc1_biases_gs_ms_os_strides;
std::vector<index_t> acc1_bias_gs_ms_os_lengths;
std::vector<index_t> acc1_bias_gs_ms_os_strides;
};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
}
// D in Gemm0 C position
static auto
MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
static auto
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
......@@ -756,8 +755,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases,
const std::vector<const void*>& p_acc1_biases,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -788,9 +787,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_biases.size()) ||
ck::type_convert<ck::index_t>(p_acc0_biases.size() == 0)) &&
0 == p_acc1_biases.size()))
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias.size() == 0)) &&
0 == p_acc1_bias.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -804,8 +803,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_biases.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_biases[i])
(ck::type_convert<ck::index_t>(p_acc0_bias.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias[i])
: nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
......@@ -827,8 +826,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_bias_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_bias_gs_ms_ns_strides;
}
else
{
......@@ -971,12 +970,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_n_length_stride});
}
// TODO: implement bias addition
// ignore = p_acc0_biases;
// ignore = p_acc1_biases;
// ignore = acc0_biases_gs_ms_ns_lengths;
// ignore = acc0_biases_gs_ms_ns_strides;
// ignore = acc1_biases_gs_ms_gemm1ns_lengths;
// ignore = acc1_biases_gs_ms_gemm1ns_strides;
// ignore = p_acc0_bias;
// ignore = p_acc1_bias;
// ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias_gs_ms_gemm1ns_strides;
}
// element-wise op
......@@ -1197,8 +1196,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases,
const std::vector<const void*>& p_acc1_biases,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1218,8 +1217,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads,
p_Kgrads,
p_Vgrads,
p_acc0_biases,
p_acc1_biases,
p_acc0_bias,
p_acc1_bias,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1245,8 +1244,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases,
const std::vector<const void*>& p_acc1_biases,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1266,8 +1265,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Qgrads,
p_Kgrads,
p_Vgrads,
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
p_acc0_bias, // cast in struct Argument
p_acc1_bias, // cast in struct Argument
problem_desc_vec,
a_element_op,
b_element_op,
......
......@@ -306,11 +306,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides;
std::vector<index_t> acc0_biases_gs_ms_ns_lengths;
std::vector<index_t> acc0_biases_gs_ms_ns_strides;
std::vector<index_t> acc0_bias_gs_ms_ns_lengths;
std::vector<index_t> acc0_bias_gs_ms_ns_strides;
std::vector<index_t> acc1_biases_gs_ms_os_lengths;
std::vector<index_t> acc1_biases_gs_ms_os_strides;
std::vector<index_t> acc1_bias_gs_ms_os_lengths;
std::vector<index_t> acc1_bias_gs_ms_os_strides;
};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -497,22 +497,21 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
}
// D in Gemm0 C position
static auto
MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
static auto
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
......@@ -764,8 +763,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases,
const std::vector<const void*>& p_acc1_biases,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -796,9 +795,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_biases.size()) ||
ck::type_convert<ck::index_t>(p_acc0_biases.size() == 0)) &&
0 == p_acc1_biases.size()))
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias.size() == 0)) &&
0 == p_acc1_bias.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -812,8 +811,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_biases.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_biases[i])
(ck::type_convert<ck::index_t>(p_acc0_bias.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_bias[i])
: nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
......@@ -835,8 +834,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_bias_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_bias_gs_ms_ns_strides;
}
else
{
......@@ -979,12 +978,12 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_n_length_stride});
}
// TODO: implement bias addition
// ignore = p_acc0_biases;
// ignore = p_acc1_biases;
// ignore = acc0_biases_gs_ms_ns_lengths;
// ignore = acc0_biases_gs_ms_ns_strides;
// ignore = acc1_biases_gs_ms_gemm1ns_lengths;
// ignore = acc1_biases_gs_ms_gemm1ns_strides;
// ignore = p_acc0_bias;
// ignore = p_acc1_bias;
// ignore = acc0_bias_gs_ms_ns_lengths;
// ignore = acc0_bias_gs_ms_ns_strides;
// ignore = acc1_bias_gs_ms_gemm1ns_lengths;
// ignore = acc1_bias_gs_ms_gemm1ns_strides;
}
// element-wise op
......@@ -1209,8 +1208,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases,
const std::vector<const void*>& p_acc1_biases,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1230,8 +1229,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads,
p_Kgrads,
p_Vgrads,
p_acc0_biases,
p_acc1_biases,
p_acc0_bias,
p_acc1_bias,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1257,8 +1256,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Qgrads,
std::vector<void*>& p_Kgrads,
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_biases,
const std::vector<const void*>& p_acc1_biases,
const std::vector<const void*>& p_acc0_bias,
const std::vector<const void*>& p_acc1_bias,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1278,8 +1277,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Qgrads,
p_Kgrads,
p_Vgrads,
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
p_acc0_bias, // cast in struct Argument
p_acc1_bias, // cast in struct Argument
problem_desc_vec,
a_element_op,
b_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment