Commit 35379cdb authored by letaoqin's avatar letaoqin
Browse files

add grad bias example

parent 2f2f5490
...@@ -374,8 +374,8 @@ int run(int argc, char* argv[]) ...@@ -374,8 +374,8 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides = std::vector<ck::index_t> d0_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
...@@ -396,7 +396,7 @@ int run(int argc, char* argv[]) ...@@ -396,7 +396,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); Tensor<Acc0BiasDataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
...@@ -405,7 +405,7 @@ int run(int argc, char* argv[]) ...@@ -405,7 +405,7 @@ int run(int argc, char* argv[])
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; std::cout << "d0_gs_ms_ns: " << d0_gs_ms_ns.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl; std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
...@@ -420,36 +420,35 @@ int run(int argc, char* argv[]) ...@@ -420,36 +420,35 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
// d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break; break;
case 2: case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break; break;
case 4: case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break; break;
case 5: case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
...@@ -457,7 +456,7 @@ int run(int argc, char* argv[]) ...@@ -457,7 +456,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -471,7 +470,7 @@ int run(int argc, char* argv[]) ...@@ -471,7 +470,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -485,7 +484,7 @@ int run(int argc, char* argv[]) ...@@ -485,7 +484,7 @@ int run(int argc, char* argv[])
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
...@@ -494,11 +493,11 @@ int run(int argc, char* argv[]) ...@@ -494,11 +493,11 @@ int run(int argc, char* argv[])
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem dgrad_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data()); q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data()); k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data()); v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data()); ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
...@@ -517,9 +516,9 @@ int run(int argc, char* argv[]) ...@@ -517,9 +516,9 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias; static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias; nullptr, // p_acc1_bias;
static_cast<Acc0BiasDataType*>(dgrad_device_buf.GetDeviceBuffer()), static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()),
nullptr, nullptr,
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
...@@ -532,10 +531,10 @@ int run(int argc, char* argv[]) ...@@ -532,10 +531,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides, {}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{alpha},
...@@ -564,9 +563,9 @@ int run(int argc, char* argv[]) ...@@ -564,9 +563,9 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias; static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias; nullptr, // p_acc1_bias;
static_cast<Acc0BiasDataType*>(dgrad_device_buf.GetDeviceBuffer()), static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()),
nullptr, nullptr,
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
...@@ -579,10 +578,10 @@ int run(int argc, char* argv[]) ...@@ -579,10 +578,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides, {}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
Scale{alpha}, Scale{alpha},
...@@ -623,7 +622,7 @@ int run(int argc, char* argv[]) ...@@ -623,7 +622,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_g_m_k({BatchCount, M, K}); Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K}); Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N}); Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O}); Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N}); Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
...@@ -645,13 +644,13 @@ int run(int argc, char* argv[]) ...@@ -645,13 +644,13 @@ int run(int argc, char* argv[])
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
}); });
d_gs_ms_ns.ForEach([&](auto& self, auto idx) { d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
// run fwd again for y, cause z_g_m_n update // run fwd again for y, cause z_g_m_n update
run_attention_fwd_host(q_g_m_k, run_attention_fwd_host(q_g_m_k,
k_g_n_k, k_g_n_k,
d_g_m_n, d0_g_m_n,
v_g_n_o, v_g_n_o,
alpha, alpha,
s_g_m_n, s_g_m_n,
...@@ -788,14 +787,19 @@ int run(int argc, char* argv[]) ...@@ -788,14 +787,19 @@ int run(int argc, char* argv[])
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data()); qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data()); kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data()); vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
d0grad_device_buf.FromDevice(d0grad_gs_ms_ns_device_result.mData.data());
// permute // permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) { qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
...@@ -823,6 +827,15 @@ int run(int argc, char* argv[]) ...@@ -823,6 +827,15 @@ int run(int argc, char* argv[])
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]); self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
}); });
d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
});
std::cout << "Checking qgrad:\n"; std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData, pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData, qgrad_gs_ms_ks_host_result.mData,
...@@ -841,6 +854,12 @@ int run(int argc, char* argv[]) ...@@ -841,6 +854,12 @@ int run(int argc, char* argv[])
"error", "error",
1e-2, 1e-2,
1e-2); 1e-2);
std::cout << "Checking d0grad:\n";
pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
d0grad_gs_ms_ns_host_result.mData,
"error",
1e-2,
1e-2);
} }
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1); return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
......
...@@ -2349,7 +2349,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2349,7 +2349,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize()); D0Operator::d0_thread_desc_.GetElementSpaceSize());
ignore = d0_thread_buf;
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds // load data to lds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment