Commit 35379cdb authored by letaoqin's avatar letaoqin
Browse files

add grad bias example

parent 2f2f5490
......@@ -374,8 +374,8 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides =
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
......@@ -396,7 +396,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<Acc0BiasDataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
......@@ -405,7 +405,7 @@ int run(int argc, char* argv[])
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "d0_gs_ms_ns: " << d0_gs_ms_ns.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
......@@ -420,36 +420,35 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
// d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break;
case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// dO dot O = [0; 1; 2; ...]
break;
case 6:
......@@ -457,7 +456,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -471,7 +470,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -485,7 +484,7 @@ int run(int argc, char* argv[])
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
......@@ -494,11 +493,11 @@ int run(int argc, char* argv[])
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem dgrad_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
......@@ -517,9 +516,9 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias;
static_cast<Acc0BiasDataType*>(dgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias;
static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()),
nullptr,
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
......@@ -532,10 +531,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
......@@ -564,9 +563,9 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias;
static_cast<Acc0BiasDataType*>(dgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias;
static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()),
nullptr,
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
......@@ -579,10 +578,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
......@@ -623,7 +622,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
......@@ -645,13 +644,13 @@ int run(int argc, char* argv[])
v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
// run fwd again for y, cause z_g_m_n update
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
d_g_m_n,
d0_g_m_n,
v_g_n_o,
alpha,
s_g_m_n,
......@@ -788,14 +787,19 @@ int run(int argc, char* argv[])
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
d0grad_device_buf.FromDevice(d0grad_gs_ms_ns_device_result.mData.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
......@@ -823,6 +827,15 @@ int run(int argc, char* argv[])
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
});
std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData,
......@@ -841,6 +854,12 @@ int run(int argc, char* argv[])
"error",
1e-2,
1e-2);
std::cout << "Checking d0grad:\n";
pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
d0grad_gs_ms_ns_host_result.mData,
"error",
1e-2,
1e-2);
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
......
......@@ -2349,7 +2349,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
ignore = d0_thread_buf;
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment