Commit 198558c5 authored by danyao12

train mqa/gqa

parent 6a2d7c9f
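This commit adds multi-query / grouped-query attention (MQA/GQA) support to the attention backward examples: a number-of-KV-heads dimension G2 (h_kv) is introduced next to the number of query heads G1 (h_q), the K and V tensor descriptors are built with G2 instead of G1, and dedicated KGrad/VGrad descriptors stay at G1 because the backward pass produces one K/V gradient slice per query head. The host reference broadcasts each K/V head to the G1 / G2 query heads that share it. A minimal sketch of that head mapping, in plain C++ with illustrative names (not taken from the CK sources):

// Sketch only: GQA maps every query head g1 to the K/V head g2 = g1 / (G1 / G2).
// MQA is the special case G2 == 1; G2 == G1 is ordinary multi-head attention.
#include <cassert>
#include <cstdio>

int main()
{
    const int G1 = 6; // h_q : number of query heads
    const int G2 = 2; // h_kv: number of K/V heads
    assert(G1 % G2 == 0);
    const int h_ratio = G1 / G2; // query heads per K/V head

    for(int g1 = 0; g1 < G1; ++g1)
    {
        const int g2 = g1 / h_ratio; // same mapping as the ForEach copies below
        std::printf("query head %d -> kv head %d\n", g1, g2);
    }
    return 0;
}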
@@ -277,7 +277,7 @@ int run(int argc, char* argv[])
     ck::index_t O  = DIM;
     ck::index_t G0 = 4;
     ck::index_t G1 = 6; // h_q
-    ck::index_t G2 = 1; // h_kv
+    ck::index_t G2 = 6; // h_kv

     bool input_permute  = false;
     bool output_permute = false;
@@ -305,8 +305,9 @@ int run(int argc, char* argv[])
     ck::index_t M = 500; // 512
     ck::index_t K = DIM;
     ck::index_t O = DIM;
-    ck::index_t G0 = 4; // 54
-    ck::index_t G1 = 6; // 16
+    ck::index_t G0 = 4;
+    ck::index_t G1 = 6; // h_q
+    ck::index_t G2 = 6; // h_kv

     bool input_permute  = false;
     bool output_permute = false;
@@ -325,7 +326,7 @@ int run(int argc, char* argv[])
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);
    }
-   else if(argc == 13)
+   else if(argc == 14)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
@@ -337,20 +338,21 @@ int run(int argc, char* argv[])
        O  = std::stoi(argv[7]);
        G0 = std::stoi(argv[8]);
        G1 = std::stoi(argv[9]);
+       G2 = std::stoi(argv[10]);

-       p_drop         = std::stof(argv[10]);
-       input_permute  = std::stoi(argv[11]);
-       output_permute = std::stoi(argv[12]);
+       p_drop         = std::stof(argv[11]);
+       input_permute  = std::stoi(argv[12]);
+       output_permute = std::stoi(argv[13]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
-       printf("arg4 to 11: M, N, K, O, G0, G1\n");
-       printf("arg10: scale (alpha)\n");
-       printf("arg11 to 12: input / output permute\n");
+       printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+       printf("arg11: p_drop\n");
+       printf("arg12 to 13: input / output permute\n");
        exit(0);
    }
@@ -368,6 +370,7 @@ int run(int argc, char* argv[])
    std::cout << "O: " << O << std::endl;
    std::cout << "G0: " << G0 << std::endl;
    std::cout << "G1: " << G1 << std::endl;
+   std::cout << "G2: " << G2 << std::endl;
    std::cout << "alpha: " << alpha << std::endl;
    std::cout << "input_permute: " << input_permute << std::endl;
    std::cout << "output_permute: " << output_permute << std::endl;
@@ -383,17 +386,17 @@ int run(int argc, char* argv[])
            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]

-   std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
+   std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
    std::vector<ck::index_t> k_gs_ns_ks_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
-           : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
+           ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+           : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]

-   std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
+   std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
    std::vector<ck::index_t> v_gs_os_ns_strides =
        input_permute
-           ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
-           : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
+           ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+           : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]

    std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
    std::vector<ck::index_t> y_gs_ms_os_strides =
...@@ -406,6 +409,18 @@ int run(int argc, char* argv[]) ...@@ -406,6 +409,18 @@ int run(int argc, char* argv[])
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
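In other words, once the forward pass has stored LSE = log(exp(S0) + exp(S1) + ...), each softmax entry can be rebuilt as Pi = exp(Si - LSE) with no renormalization. A small standalone illustration of the identity (not CK code; a production version would subtract max(S) first for numerical stability):

#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_from_lse(const std::vector<float>& s)
{
    // The forward pass saves this one scalar per softmax row ...
    float sum = 0.f;
    for(float si : s)
        sum += std::exp(si);
    const float lse = std::log(sum);

    // ... so the backward pass recovers P with one exp per element.
    std::vector<float> p(s.size());
    for(std::size_t i = 0; i < s.size(); ++i)
        p[i] = std::exp(s[i] - lse);
    return p;
}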
@@ -424,8 +439,10 @@ int run(int argc, char* argv[])
    Tensor<InputDataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
    Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
    Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-   Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-   Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+   Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
+                                                       kgrad_gs_ns_ks_strides);
+   Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
+                                                       vgrad_gs_os_ns_strides);

    std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
    std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
@@ -612,6 +629,10 @@ int run(int argc, char* argv[])
        y_gs_ms_os_lengths,
        y_gs_ms_os_strides,
        lse_gs_ms_lengths,
+       kgrad_gs_ns_ks_lengths,
+       kgrad_gs_ns_ks_strides,
+       vgrad_gs_os_ns_lengths,
+       vgrad_gs_os_ns_strides,
        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -656,8 +677,10 @@ int run(int argc, char* argv[])
    Tensor<InputDataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
    Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
    Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-   Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-   Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+   Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
+                                                     kgrad_gs_ns_ks_strides);
+   Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
+                                                     vgrad_gs_os_ns_strides);

    Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
    Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
@@ -760,6 +783,10 @@ int run(int argc, char* argv[])
        y_gs_ms_os_lengths,
        y_gs_ms_os_strides,
        lse_gs_ms_lengths,
+       kgrad_gs_ns_ks_lengths,
+       kgrad_gs_ns_ks_strides,
+       vgrad_gs_os_ns_lengths,
+       vgrad_gs_os_ns_strides,
        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
        {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -795,11 +822,19 @@ int run(int argc, char* argv[]) ...@@ -795,11 +822,19 @@ int run(int argc, char* argv[])
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_g_n_k.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
}); });
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_g_n_o.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
}); });
z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) { z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_fwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_fwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
......
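Note the direction of iteration flipped: instead of scattering from the (now smaller) K/V tensors, the copies gather into the flat per-batch reference tensors, so each K/V head is read once per query head that shares it. The grouped-attention example below gets the same treatment, but with the head ratio exposed as a runtime parameter h_ratio = G1 / G2 and G1 derived from a randomly chosen G2. A loop-based sketch of the same broadcast, standalone and with hypothetical names:

// Broadcast K of shape [G0, G2, N, K] into the flat reference layout [G0*G1, N, K],
// mirroring the k_g_n_k.ForEach copy above (plain nested loops, row-major buffers).
#include <vector>

void broadcast_k(const std::vector<float>& k_gs, std::vector<float>& k_flat,
                 int G0, int G1, int G2, int N, int K)
{
    const int h_ratio = G1 / G2;
    for(int g = 0; g < G0 * G1; ++g) // flat batch index g = g0 * G1 + g1
    {
        const int g0 = g / G1;
        const int g2 = (g % G1) / h_ratio; // K/V head shared by query head g % G1
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                k_flat[(g * N + n) * K + k] = k_gs[((g0 * G2 + g2) * N + n) * K + k];
    }
}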
@@ -302,6 +302,7 @@ int run(int argc, char* argv[])
    // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
    float alpha  = 1.f / std::sqrt(DIM);
    float p_drop = 0.2;
+   int h_ratio  = 1; // G1 / G2

    bool input_permute  = true;
    bool output_permute = true;
...@@ -319,25 +320,26 @@ int run(int argc, char* argv[]) ...@@ -319,25 +320,26 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 7) else if(argc == 8)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
p_drop = std::stof(argv[4]); p_drop = std::stof(argv[4]);
h_ratio = std::stof(argv[5]);
input_permute = std::stoi(argv[5]); input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[6]); output_permute = std::stoi(argv[7]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4: p_drop\n");
printf("arg10: scale (alpha)\n"); printf("arg5: h_ratio\n");
printf("arg11 to 12: input / output permute\n"); printf("arg6 to 7: input / output permute\n");
exit(0); exit(0);
} }
...@@ -412,24 +414,25 @@ int run(int argc, char* argv[]) ...@@ -412,24 +414,25 @@ int run(int argc, char* argv[])
int K = DIM; int K = DIM;
int O = DIM; int O = DIM;
int G0 = rand() % 4 + 1; int G0 = rand() % 4 + 1;
int G1 = rand() % 4 + 1; int G2 = rand() % 4 + 1;
int G1 = G2 * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
...@@ -442,6 +445,17 @@ int run(int argc, char* argv[]) ...@@ -442,6 +445,17 @@ int run(int argc, char* argv[])
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
@@ -481,6 +495,10 @@ int run(int argc, char* argv[])
            y_gs_ms_os_strides,
            lse_gs_ms_lengths,
            lse_gs_ms_strides,
+           kgrad_gs_ns_ks_lengths,
+           kgrad_gs_ns_ks_strides,
+           vgrad_gs_os_ns_lengths,
+           vgrad_gs_os_ns_strides,
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -510,6 +528,8 @@ int run(int argc, char* argv[])
        Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
        Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
        Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
+       Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
+       Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
        if(i < 4)
        {
            std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
@@ -518,6 +538,8 @@ int run(int argc, char* argv[])
            std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
            std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
            std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
+           std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
+           std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
        }
        z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
        z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
@@ -598,11 +620,19 @@ int run(int argc, char* argv[])
        q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
            q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
        });
-       k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-           k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
-       });
+       k_g_n_k.ForEach([&](auto& self, auto idx) {
+           const size_t g0 = idx[0] / G1;
+           const size_t g1 = idx[0] % G1;
+           const size_t g2 = g1 / h_ratio;
+           self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
+       });
+
-       v_gs_os_ns.ForEach([&](auto& self, auto idx) {
-           v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
-       });
+       v_g_n_o.ForEach([&](auto& self, auto idx) {
+           const size_t g0 = idx[0] / G1;
+           const size_t g1 = idx[0] % G1;
+           const size_t g2 = g1 / h_ratio;
+           self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
+       });
+
        q_g_m_ks.push_back(q_g_m_k);
@@ -624,6 +654,8 @@ int run(int argc, char* argv[])
        z_bwd_tensors.push_back(z_bwd_gs_ms_ns);
        lse_tensors.push_back(lse_gs_ms);
        ygrad_tensors.push_back(ygrad_gs_ms_os);
+       kgrad_tensors.push_back(kgrad_gs_ns_ks);
+       vgrad_tensors.push_back(vgrad_gs_os_ns);

        q_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
@@ -641,10 +673,10 @@ int run(int argc, char* argv[])
            std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
        qgrad_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
-       kgrad_tensors_device.emplace_back(
-           std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
-       vgrad_tensors_device.emplace_back(
-           std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
+       kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+           sizeof(OutputDataType) * kgrad_gs_ns_ks.GetElementSpaceSize()));
+       vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+           sizeof(OutputDataType) * vgrad_gs_os_ns.GetElementSpaceSize()));
        ygrad_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
...@@ -689,7 +721,8 @@ int run(int argc, char* argv[]) ...@@ -689,7 +721,8 @@ int run(int argc, char* argv[])
Scale{alpha}, Scale{alpha},
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should {seed, offset}); // dropout random seed and offset, offset should
// be at least the number of elements on a thread // be at least the number of elements on a thread
...@@ -737,6 +770,7 @@ int run(int argc, char* argv[]) ...@@ -737,6 +770,7 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_bwd(gemm_bwd.GetWorkSpaceSize(&argument_bwd)); DeviceMem problem_desc_workspace_bwd(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
...@@ -786,7 +820,8 @@ int run(int argc, char* argv[]) ...@@ -786,7 +820,8 @@ int run(int argc, char* argv[])
Scale{alpha}, Scale{alpha},
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, // dropout ratio p_drop, // dropout ratio
h_ratio,
{seed, offset}); // dropout random seed and offset, offset should {seed, offset}); // dropout random seed and offset, offset should
// be at least the number of elements on a thread // be at least the number of elements on a thread
...@@ -826,6 +861,7 @@ int run(int argc, char* argv[]) ...@@ -826,6 +861,7 @@ int run(int argc, char* argv[])
QKVElementOp{}, QKVElementOp{},
YElementOp{}, YElementOp{},
p_drop, p_drop,
h_ratio,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace_bwd_verify(gemm_bwd.GetWorkSpaceSize(&argument_bwd)); DeviceMem problem_desc_workspace_bwd_verify(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
gemm_bwd.SetWorkSpacePointer(&argument_bwd, gemm_bwd.SetWorkSpacePointer(&argument_bwd,
@@ -840,7 +876,7 @@ int run(int argc, char* argv[])
    for(std::size_t i = 0; i < group_count; i++)
    {
-       int G1 = v_tensors[i].GetLengths()[1];
+       int G1 = q_tensors[i].GetLengths()[1];

        // copy z matrix data from device
        z_fwd_tensors_device[i]->FromDevice(z_fwd_tensors[i].mData.data());
        z_fwd_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -863,7 +899,7 @@ int run(int argc, char* argv[])
            p_dropout_in_uint8_t,
            rp_dropout);

-       int G0 = v_tensors[i].GetLengths()[0];
+       int G0 = q_tensors[i].GetLengths()[0];
        int O  = v_tensors[i].GetLengths()[2];
        int N  = v_tensors[i].GetLengths()[3];
        int M  = q_tensors[i].GetLengths()[2];
@@ -921,10 +957,10 @@ int run(int argc, char* argv[])
        Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
                                                          q_tensors[i].GetStrides());
-       Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
-                                                         k_tensors[i].GetStrides());
-       Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
-                                                         v_tensors[i].GetStrides());
+       Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_tensors[i].GetLengths(),
+                                                         kgrad_tensors[i].GetStrides());
+       Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_tensors[i].GetLengths(),
+                                                         vgrad_tensors[i].GetStrides());
        Tensor<InputDataType> y_gs_ms_os_host_result(y_tensors[i].GetLengths(),
                                                     y_tensors[i].GetStrides());
        Tensor<LSEDataType> lse_gs_ms_host_result(lse_tensors[i].GetLengths(),
@@ -932,10 +968,10 @@ int run(int argc, char* argv[])
        Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
                                                            q_tensors[i].GetStrides());
-       Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
-                                                           k_tensors[i].GetStrides());
-       Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
-                                                           v_tensors[i].GetStrides());
+       Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_tensors[i].GetLengths(),
+                                                           kgrad_tensors[i].GetStrides());
+       Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_tensors[i].GetLengths(),
+                                                           vgrad_tensors[i].GetStrides());
        Tensor<InputDataType> y_gs_ms_os_device_result(y_tensors[i].GetLengths(),
                                                       y_tensors[i].GetStrides());
        Tensor<LSEDataType> lse_gs_ms_device_result(lse_tensors[i].GetLengths(),
                                                    lse_tensors[i].GetStrides());
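After this change the batched example reads its arguments as do_verification, init_method, time_kernel, then M, N, K, O, G0, G1, G2, then p_drop and the input/output permute flags (argc == 14). The grouped example reads do_verification, init_method, time_kernel, p_drop, h_ratio, and the two permute flags (argc == 8). Running with G2 == G1 (h_ratio == 1) reproduces the previous multi-head behavior; G2 == 1 exercises MQA, and any other divisor of G1 exercises GQA.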