Commit 29398e70 authored by danyao12's avatar danyao12
Browse files

update 52 examples w/ mqa/gqa

parent 617bdf3f
...@@ -280,7 +280,8 @@ int run(int argc, char* argv[]) ...@@ -280,7 +280,8 @@ int run(int argc, char* argv[])
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; ck::index_t G0 = 4;
ck::index_t G1 = 6; ck::index_t G1 = 6; // h_q
ck::index_t G2 = 6; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -299,7 +300,7 @@ int run(int argc, char* argv[]) ...@@ -299,7 +300,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -311,20 +312,21 @@ int run(int argc, char* argv[]) ...@@ -311,20 +312,21 @@ int run(int argc, char* argv[])
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -342,6 +344,7 @@ int run(int argc, char* argv[]) ...@@ -342,6 +344,7 @@ int run(int argc, char* argv[])
std::cout << "O: " << O << std::endl; std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl; std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl; std::cout << "G1: " << G1 << std::endl;
std::cout << "G2: " << G2 << std::endl;
std::cout << "alpha: " << alpha << std::endl; std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl; std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl; std::cout << "output_permute: " << output_permute << std::endl;
...@@ -357,17 +360,17 @@ int run(int argc, char* argv[]) ...@@ -357,17 +360,17 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
...@@ -386,6 +389,18 @@ int run(int argc, char* argv[]) ...@@ -386,6 +389,18 @@ int run(int argc, char* argv[])
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
...@@ -403,6 +418,8 @@ int run(int argc, char* argv[]) ...@@ -403,6 +418,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
...@@ -411,6 +428,8 @@ int run(int argc, char* argv[]) ...@@ -411,6 +428,8 @@ int run(int argc, char* argv[])
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -491,8 +510,8 @@ int run(int argc, char* argv[]) ...@@ -491,8 +510,8 @@ int run(int argc, char* argv[])
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize()); DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(OutputDataType) * kgrad_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(OutputDataType) * vgrad_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
...@@ -533,6 +552,10 @@ int run(int argc, char* argv[]) ...@@ -533,6 +552,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
...@@ -580,6 +603,10 @@ int run(int argc, char* argv[]) ...@@ -580,6 +603,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
...@@ -640,11 +667,19 @@ int run(int argc, char* argv[]) ...@@ -640,11 +667,19 @@ int run(int argc, char* argv[])
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_g_n_k.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
}); });
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_g_n_o.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
}); });
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) { d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
...@@ -787,14 +822,18 @@ int run(int argc, char* argv[]) ...@@ -787,14 +822,18 @@ int run(int argc, char* argv[])
#endif #endif
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths, Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides); d0_gs_ms_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths, Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides); d0_gs_ms_ns_strides);
......
...@@ -275,6 +275,7 @@ int run(int argc, char* argv[]) ...@@ -275,6 +275,7 @@ int run(int argc, char* argv[])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
float alpha = 1.f / std::sqrt(DIM); float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.0; float p_drop = 0.0;
int h_ratio = 1; // G1 / G2
bool input_permute = true; bool input_permute = true;
bool output_permute = true; bool output_permute = true;
...@@ -292,25 +293,26 @@ int run(int argc, char* argv[]) ...@@ -292,25 +293,26 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 7) else if(argc == 8)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
p_drop = std::stof(argv[4]); p_drop = std::stof(argv[4]);
h_ratio = std::stof(argv[5]);
input_permute = std::stoi(argv[5]); input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[6]); output_permute = std::stoi(argv[7]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4: p_drop\n");
printf("arg10: scale (alpha)\n"); printf("arg5: h_ratio\n");
printf("arg11 to 12: input / output permute\n"); printf("arg6 to 7: input / output permute\n");
exit(0); exit(0);
} }
...@@ -382,24 +384,25 @@ int run(int argc, char* argv[]) ...@@ -382,24 +384,25 @@ int run(int argc, char* argv[])
int K = DIM; int K = DIM;
int O = DIM; int O = DIM;
int G0 = rand() % 4 + 1; int G0 = rand() % 4 + 1;
int G1 = rand() % 4 + 1; int G2 = rand() % 4 + 1;
int G1 = G2 * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
...@@ -418,6 +421,17 @@ int run(int argc, char* argv[]) ...@@ -418,6 +421,17 @@ int run(int argc, char* argv[])
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
...@@ -439,6 +453,10 @@ int run(int argc, char* argv[]) ...@@ -439,6 +453,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
lse_gs_ms_strides, lse_gs_ms_strides,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
d0_gs_ms_ns_lengths, d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides, d0_gs_ms_ns_strides,
{}, // acc1_bias_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
...@@ -464,6 +482,8 @@ int run(int argc, char* argv[]) ...@@ -464,6 +482,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
if(i < 4) if(i < 4)
{ {
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
...@@ -473,6 +493,8 @@ int run(int argc, char* argv[]) ...@@ -473,6 +493,8 @@ int run(int argc, char* argv[])
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
} }
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -558,14 +580,22 @@ int run(int argc, char* argv[]) ...@@ -558,14 +580,22 @@ int run(int argc, char* argv[])
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_g_n_k.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
}); });
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) { d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_g_n_o.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
}); });
q_g_m_ks.push_back(q_g_m_k); q_g_m_ks.push_back(q_g_m_k);
...@@ -586,6 +616,8 @@ int run(int argc, char* argv[]) ...@@ -586,6 +616,8 @@ int run(int argc, char* argv[])
z_tensors.push_back(z_gs_ms_ns); z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms); lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os); ygrad_tensors.push_back(ygrad_gs_ms_os);
kgrad_tensors.push_back(kgrad_gs_ns_ks);
vgrad_tensors.push_back(vgrad_gs_os_ns);
q_tensors_device.emplace_back( q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back( k_tensors_device.emplace_back(
...@@ -602,12 +634,12 @@ int run(int argc, char* argv[]) ...@@ -602,12 +634,12 @@ int run(int argc, char* argv[])
std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
qgrad_tensors_device.emplace_back( qgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back( kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize())); sizeof(OutputDataType) * kgrad_gs_ns_ks.GetElementSpaceSize()));
d0grad_tensors_device.emplace_back(std::make_unique<DeviceMem>( d0grad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize())); sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back( vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize())); sizeof(OutputDataType) * vgrad_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back( ygrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
q_tensors_device.back()->ToDevice(q_gs_ms_ks.data()); q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
...@@ -652,6 +684,7 @@ int run(int argc, char* argv[]) ...@@ -652,6 +684,7 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -700,6 +733,7 @@ int run(int argc, char* argv[]) ...@@ -700,6 +733,7 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer()); gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
...@@ -713,7 +747,7 @@ int run(int argc, char* argv[]) ...@@ -713,7 +747,7 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int G1 = v_tensors[i].GetLengths()[1]; int G1 = q_tensors[i].GetLengths()[1];
// copy z matirx data form device // copy z matirx data form device
z_tensors_device[i]->FromDevice(z_tensors[i].mData.data()); z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
z_tensors[i].ForEach([&](auto& self, auto idx) { z_tensors[i].ForEach([&](auto& self, auto idx) {
...@@ -752,8 +786,8 @@ int run(int argc, char* argv[]) ...@@ -752,8 +786,8 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int G0 = v_tensors[i].GetLengths()[0]; int G0 = q_tensors[i].GetLengths()[0];
int G1 = v_tensors[i].GetLengths()[1]; int G1 = q_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2]; int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3]; int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2]; int M = q_tensors[i].GetLengths()[2];
...@@ -814,21 +848,21 @@ int run(int argc, char* argv[]) ...@@ -814,21 +848,21 @@ int run(int argc, char* argv[])
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); kgrad_tensors[i].GetStrides());
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_tensors[i].GetLengths(), Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_tensors[i].GetLengths(),
d0_tensors[i].GetStrides()); d0_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); vgrad_tensors[i].GetStrides());
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); kgrad_tensors[i].GetStrides());
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_tensors[i].GetLengths(), Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_tensors[i].GetLengths(),
d0_tensors[i].GetStrides()); d0_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); vgrad_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data()); qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data()); kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
......
...@@ -18,7 +18,8 @@ int run(int argc, char* argv[]) ...@@ -18,7 +18,8 @@ int run(int argc, char* argv[])
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7; ck::index_t G0 = 7;
ck::index_t G1 = 13; ck::index_t G1 = 12; // h_q
ck::index_t G2 = 12; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
...@@ -37,7 +38,7 @@ int run(int argc, char* argv[]) ...@@ -37,7 +38,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
...@@ -49,20 +50,21 @@ int run(int argc, char* argv[]) ...@@ -49,20 +50,21 @@ int run(int argc, char* argv[])
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -77,17 +79,17 @@ int run(int argc, char* argv[]) ...@@ -77,17 +79,17 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides = std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides = std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides = std::vector<ck::index_t> c_gs_ms_os_strides =
...@@ -323,11 +325,19 @@ int run(int argc, char* argv[]) ...@@ -323,11 +325,19 @@ int run(int argc, char* argv[])
a_gs_ms_ks.ForEach([&](auto& self, auto idx) { a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { b0_g_k_n.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
}); });
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { b1_g_n_o.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
}); });
d_gs_ms_ns.ForEach([&](auto& self, auto idx) { d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
......
...@@ -11,6 +11,7 @@ int run(int argc, char* argv[]) ...@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
bool output_permute = true; bool output_permute = true;
float p_drop = 0.2; float p_drop = 0.2;
int h_ratio = 1; // G1 / G2
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -24,22 +25,25 @@ int run(int argc, char* argv[]) ...@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 7) else if(argc == 8)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
p_drop = std::stoi(argv[4]); p_drop = std::stoi(argv[4]);
input_permute = std::stoi(argv[5]); h_ratio = std::stof(argv[5]);
output_permute = std::stoi(argv[6]); input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[7]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 5: input / output permute\n"); printf("arg4: p_drop\n");
printf("arg5: h_ratio\n");
printf("arg6 to 7: input / output permute\n");
exit(0); exit(0);
} }
...@@ -91,7 +95,8 @@ int run(int argc, char* argv[]) ...@@ -91,7 +95,8 @@ int run(int argc, char* argv[])
int K = DIM; int K = DIM;
int O = DIM; int O = DIM;
int G0 = rand() % 3 + 1; int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1; int G2 = rand() % 5 + 1;
int G1 = G2 * h_ratio;
g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O}); g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
...@@ -101,17 +106,17 @@ int run(int argc, char* argv[]) ...@@ -101,17 +106,17 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides = std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides = std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides = std::vector<ck::index_t> c_gs_ms_os_strides =
...@@ -275,7 +280,8 @@ int run(int argc, char* argv[]) ...@@ -275,7 +280,8 @@ int run(int argc, char* argv[])
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
...@@ -330,7 +336,8 @@ int run(int argc, char* argv[]) ...@@ -330,7 +336,8 @@ int run(int argc, char* argv[])
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should be {seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread // at least the number of elements on a thread
// specify workspace for problem_desc // specify workspace for problem_desc
...@@ -395,13 +402,20 @@ int run(int argc, char* argv[]) ...@@ -395,13 +402,20 @@ int run(int argc, char* argv[])
a_gs_ms_ks.ForEach([&](auto& self, auto idx) { a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { b0_g_k_n.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1;
}); const size_t& g1 = idx[0] % G1;
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { const size_t& g2 = g1 / h_ratio;
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
}); });
b1_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
});
d_gs_ms_ns.ForEach([&](auto& self, auto idx) { d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment