Commit 198558c5 authored by danyao12

train mqa/gqa

parent 6a2d7c9f
@@ -277,7 +277,7 @@ int run(int argc, char* argv[])
 ck::index_t O = DIM;
 ck::index_t G0 = 4;
 ck::index_t G1 = 6; // h_q
-ck::index_t G2 = 1; // h_kv
+ck::index_t G2 = 6; // h_kv
 bool input_permute = false;
 bool output_permute = false;
......
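This first hunk switches the fixed-size example from MQA (G2 = 1, a single KV head shared by all G1 = h_q query heads) to a GQA-style configuration (here G2 = G1 = 6, which degenerates to ordinary MHA; any G2 that divides G1 is valid). A minimal standalone sketch of the query-head to KV-head mapping the rest of the commit relies on (illustrative helper, not a CK API):

```cpp
#include <cassert>

// GQA/MQA: h_q query heads are partitioned into h_kv groups of h_q / h_kv
// consecutive heads, each group attending to one shared K/V head.
int kv_head_for_query_head(int g1, int h_q, int h_kv)
{
    assert(h_kv > 0 && h_q % h_kv == 0); // group size must divide evenly
    const int h_ratio = h_q / h_kv;
    return g1 / h_ratio; // query heads [g2*h_ratio, (g2+1)*h_ratio) share KV head g2
}
// h_kv == 1   -> MQA: every query head reads KV head 0
// h_kv == h_q -> MHA: identity mapping
```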
@@ -305,8 +305,9 @@ int run(int argc, char* argv[])
 ck::index_t M = 500; // 512
 ck::index_t K = DIM;
 ck::index_t O = DIM;
-ck::index_t G0 = 4; // 54
-ck::index_t G1 = 6; // 16
+ck::index_t G0 = 4;
+ck::index_t G1 = 6; // h_q
+ck::index_t G2 = 6; // h_kv
 bool input_permute = false;
 bool output_permute = false;
@@ -325,7 +326,7 @@ int run(int argc, char* argv[])
 init_method = std::stoi(argv[2]);
 time_kernel = std::stoi(argv[3]);
 }
-else if(argc == 13)
+else if(argc == 14)
 {
 do_verification = std::stoi(argv[1]);
 init_method = std::stoi(argv[2]);
@@ -337,20 +338,21 @@ int run(int argc, char* argv[])
 O = std::stoi(argv[7]);
 G0 = std::stoi(argv[8]);
 G1 = std::stoi(argv[9]);
-p_drop = std::stof(argv[10]);
-input_permute = std::stoi(argv[11]);
-output_permute = std::stoi(argv[12]);
+G2 = std::stoi(argv[10]);
+p_drop = std::stof(argv[11]);
+input_permute = std::stoi(argv[12]);
+output_permute = std::stoi(argv[13]);
 }
 else
 {
 printf("arg1: verification (0=no, 1=yes)\n");
 printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
 printf("arg3: time kernel (0=no, 1=yes)\n");
-printf("arg4 to 11: M, N, K, O, G0, G1\n");
-printf("arg10: scale (alpha)\n");
-printf("arg11 to 12: input / output permute\n");
+printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+printf("arg11: p_drop\n");
+printf("arg12 to 13: input / output permute\n");
 exit(0);
 }
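With argc == 14, the positional arguments are argv[4..10] = M, N, K, O, G0, G1, G2, argv[11] = p_drop, and argv[12..13] = the permute flags; the old help text, which still described the pre-G2 layout and a scale argument, is replaced accordingly. A hypothetical invocation (binary name and values are illustrative only, not from this commit):

```cpp
// ./attention_bwd_example 1 2 1  512 512 64 64  4 8 2  0.2  0 0
//   do_verification=1, init_method=2, time_kernel=1,
//   M=512, N=512, K=64, O=64,
//   G0=4 (batch), G1=8 (h_q query heads), G2=2 (h_kv KV heads,
//   so each KV head serves 8/2 = 4 query heads),
//   p_drop=0.2, input_permute=0, output_permute=0
```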
@@ -368,6 +370,7 @@ int run(int argc, char* argv[])
 std::cout << "O: " << O << std::endl;
 std::cout << "G0: " << G0 << std::endl;
 std::cout << "G1: " << G1 << std::endl;
+std::cout << "G2: " << G2 << std::endl;
 std::cout << "alpha: " << alpha << std::endl;
 std::cout << "input_permute: " << input_permute << std::endl;
 std::cout << "output_permute: " << output_permute << std::endl;
@@ -383,17 +386,17 @@ int run(int argc, char* argv[])
 ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
 : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
-std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
+std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
 std::vector<ck::index_t> k_gs_ns_ks_strides =
 input_permute
-? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
-: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
+? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
-std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
+std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
 std::vector<ck::index_t> v_gs_os_ns_strides =
 input_permute
-? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
-: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
+? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
 std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
 std::vector<ck::index_t> y_gs_ms_os_strides =
@@ -406,6 +409,18 @@ int run(int argc, char* argv[])
 input_permute
 ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
 : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
+input_permute
+? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+std::vector<ck::index_t> vgrad_gs_os_ns_strides =
+input_permute
+? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
 // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
 // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
 // = exp(Si) / exp(log(sum(exp() + ...)))
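Written out, the identity that the stored log-sum-exp enables in the backward pass is:

```math
\mathrm{LSE}_i = \log\sum_j \exp(S_{ij}), \qquad
P_{ij} = \frac{\exp(S_{ij})}{\sum_j \exp(S_{ij})} = \exp\left(S_{ij} - \mathrm{LSE}_i\right)
```

so each probability row is rebuilt with one subtraction and one exponential instead of a fresh softmax reduction; in practice the LSE itself is accumulated in the max-subtracted form m_i + log(sum_j exp(S_ij - m_i)) for numerical stability.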
@@ -424,8 +439,10 @@ int run(int argc, char* argv[])
 Tensor<InputDataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
 Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
 Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
+kgrad_gs_ns_ks_strides);
+Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
+vgrad_gs_os_ns_strides);
 std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
 std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
@@ -612,6 +629,10 @@ int run(int argc, char* argv[])
 y_gs_ms_os_lengths,
 y_gs_ms_os_strides,
 lse_gs_ms_lengths,
+kgrad_gs_ns_ks_lengths,
+kgrad_gs_ns_ks_strides,
+vgrad_gs_os_ns_lengths,
+vgrad_gs_os_ns_strides,
 {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
 {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
 {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -656,8 +677,10 @@ int run(int argc, char* argv[])
 Tensor<InputDataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
 Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
 Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
+kgrad_gs_ns_ks_strides);
+Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
+vgrad_gs_os_ns_strides);
 Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
 Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
@@ -760,6 +783,10 @@ int run(int argc, char* argv[])
 y_gs_ms_os_lengths,
 y_gs_ms_os_strides,
 lse_gs_ms_lengths,
+kgrad_gs_ns_ks_lengths,
+kgrad_gs_ns_ks_strides,
+vgrad_gs_os_ns_lengths,
+vgrad_gs_os_ns_strides,
 {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
 {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
 {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -795,11 +822,19 @@ int run(int argc, char* argv[])
 q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
 q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
 });
-k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+k_g_n_k.ForEach([&](auto& self, auto idx) {
+const size_t& g0 = idx[0] / G1;
+const size_t& g1 = idx[0] % G1;
+const size_t& g2 = g1 / (G1 / G2);
+self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
 });
-v_gs_os_ns.ForEach([&](auto& self, auto idx) {
-v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+v_g_n_o.ForEach([&](auto& self, auto idx) {
+const size_t& g0 = idx[0] / G1;
+const size_t& g1 = idx[0] % G1;
+const size_t& g2 = g1 / (G1 / G2);
+self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
 });
 z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
 z_fwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
......
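Note the direction flip in the host-side copies above: Q is still scattered from its (G0, G1, M, K) view into the flat g = g0 * G1 + g1 batch, but K and V are now gathered, since a (G0, G2, ...) tensor has to be broadcast to all G1 query heads before the dense reference GEMMs run. The same expansion on plain contiguous arrays (hypothetical helper for illustration, not the Tensor API):

```cpp
#include <cstddef>
#include <vector>

// Expand K from (G0, G2, N, Kd) to a flat (G0*G1, N, Kd) batch by replicating
// each KV head across the h_ratio = G1 / G2 query heads that share it.
std::vector<float> broadcast_kv(const std::vector<float>& k_gs,
                                int G0, int G1, int G2, int N, int Kd)
{
    std::vector<float> k_g(static_cast<std::size_t>(G0) * G1 * N * Kd);
    const int h_ratio = G1 / G2;
    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1 = 0; g1 < G1; ++g1)
        {
            const int g2 = g1 / h_ratio; // shared KV head for this query head
            for(int n = 0; n < N; ++n)
                for(int k = 0; k < Kd; ++k)
                    k_g[((static_cast<std::size_t>(g0) * G1 + g1) * N + n) * Kd + k] =
                        k_gs[((static_cast<std::size_t>(g0) * G2 + g2) * N + n) * Kd + k];
        }
    return k_g;
}
```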
@@ -302,6 +302,7 @@ int run(int argc, char* argv[])
 // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
 float alpha = 1.f / std::sqrt(DIM);
 float p_drop = 0.2;
+int h_ratio = 1; // G1 / G2
 bool input_permute = true;
 bool output_permute = true;
@@ -319,25 +320,26 @@ int run(int argc, char* argv[])
 init_method = std::stoi(argv[2]);
 time_kernel = std::stoi(argv[3]);
 }
-else if(argc == 7)
+else if(argc == 8)
 {
 do_verification = std::stoi(argv[1]);
 init_method = std::stoi(argv[2]);
 time_kernel = std::stoi(argv[3]);
 p_drop = std::stof(argv[4]);
-input_permute = std::stoi(argv[5]);
-output_permute = std::stoi(argv[6]);
+h_ratio = std::stoi(argv[5]);
+input_permute = std::stoi(argv[6]);
+output_permute = std::stoi(argv[7]);
 }
 else
 {
 printf("arg1: verification (0=no, 1=yes)\n");
 printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
 printf("arg3: time kernel (0=no, 1=yes)\n");
-printf("arg4 to 11: M, N, K, O, G0, G1\n");
-printf("arg10: scale (alpha)\n");
-printf("arg11 to 12: input / output permute\n");
+printf("arg4: p_drop\n");
+printf("arg5: h_ratio\n");
+printf("arg6 to 7: input / output permute\n");
 exit(0);
 }
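Here the grouped example parametrizes the head grouping by h_ratio = G1 / G2 rather than by G2 itself (h_ratio = 1 preserves the old MHA behavior; the integer h_ratio is read with std::stoi, not std::stof as the patch originally had it), and G2 is randomized per group below with G1 derived from it. A hypothetical invocation of the new eight-argument form (binary name and values illustrative only):

```cpp
// ./grouped_attention_bwd_example 1 2 1  0.2  4  1 1
//   do_verification=1, init_method=2, time_kernel=1,
//   p_drop=0.2, h_ratio=4 (each KV head serves 4 query heads),
//   input_permute=1, output_permute=1
```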
@@ -412,24 +414,25 @@ int run(int argc, char* argv[])
 int K = DIM;
 int O = DIM;
 int G0 = rand() % 4 + 1;
-int G1 = rand() % 4 + 1;
+int G2 = rand() % 4 + 1;
+int G1 = G2 * h_ratio;
 std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
 std::vector<ck::index_t> q_gs_ms_ks_strides =
 input_permute
 ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
 : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
-std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
+std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
 std::vector<ck::index_t> k_gs_ns_ks_strides =
 input_permute
-? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
-: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
+? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
+: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
-std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
+std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
 std::vector<ck::index_t> v_gs_os_ns_strides =
 input_permute
-? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
-: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
+? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
+: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
 std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
 std::vector<ck::index_t> y_gs_ms_os_strides =
@@ -442,6 +445,17 @@ int run(int argc, char* argv[])
 input_permute
 ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
 : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
+std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
+input_permute
+? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
+: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
+std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
+std::vector<ck::index_t> vgrad_gs_os_ns_strides =
+input_permute
+? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
+: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
 // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
 // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
 // = exp(Si) / exp(log(sum(exp() + ...)))
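kgrad/vgrad are declared with G1 heads even though K and V now carry only G2: each query head contributes its own partial dK/dV, and the gradient of a shared KV head is the sum of the partials from the h_ratio query heads mapped to it. A minimal sketch of that reduction on plain arrays (an illustration of the math, not this kernel's actual reduction path):

```cpp
#include <cstddef>

// dk_partial holds G2 * h_ratio per-query-head gradients, head_elems values each;
// dk receives the G2 per-KV-head gradients: dk[g2] = sum_r dk_partial[g2 * h_ratio + r].
void reduce_kv_grad(const float* dk_partial, float* dk,
                    int G2, int h_ratio, std::size_t head_elems)
{
    for(int g2 = 0; g2 < G2; ++g2)
        for(std::size_t e = 0; e < head_elems; ++e)
        {
            float acc = 0.f;
            for(int r = 0; r < h_ratio; ++r)
                acc += dk_partial[(static_cast<std::size_t>(g2) * h_ratio + r) * head_elems + e];
            dk[static_cast<std::size_t>(g2) * head_elems + e] = acc;
        }
}
```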
@@ -481,6 +495,10 @@ int run(int argc, char* argv[])
 y_gs_ms_os_strides,
 lse_gs_ms_lengths,
 lse_gs_ms_strides,
+kgrad_gs_ns_ks_lengths,
+kgrad_gs_ns_ks_strides,
+vgrad_gs_os_ns_lengths,
+vgrad_gs_os_ns_strides,
 {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
 {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
 {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
@@ -510,6 +528,8 @@ int run(int argc, char* argv[])
 Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
 Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
 Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
+Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
+Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
 if(i < 4)
 {
 std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
@@ -518,6 +538,8 @@ int run(int argc, char* argv[])
 std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
 std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
 std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
+std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
+std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
 }
 z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
 z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
@@ -598,11 +620,19 @@ int run(int argc, char* argv[])
 q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
 q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
 });
-k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+k_g_n_k.ForEach([&](auto& self, auto idx) {
+const size_t& g0 = idx[0] / G1;
+const size_t& g1 = idx[0] % G1;
+const size_t& g2 = g1 / h_ratio;
+self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
 });
-v_gs_os_ns.ForEach([&](auto& self, auto idx) {
-v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+v_g_n_o.ForEach([&](auto& self, auto idx) {
+const size_t& g0 = idx[0] / G1;
+const size_t& g1 = idx[0] % G1;
+const size_t& g2 = g1 / h_ratio;
+self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
 });
 q_g_m_ks.push_back(q_g_m_k);
@@ -624,6 +654,8 @@ int run(int argc, char* argv[])
 z_bwd_tensors.push_back(z_bwd_gs_ms_ns);
 lse_tensors.push_back(lse_gs_ms);
 ygrad_tensors.push_back(ygrad_gs_ms_os);
+kgrad_tensors.push_back(kgrad_gs_ns_ks);
+vgrad_tensors.push_back(vgrad_gs_os_ns);
 q_tensors_device.emplace_back(
 std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
@@ -641,10 +673,10 @@ int run(int argc, char* argv[])
 std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
 qgrad_tensors_device.emplace_back(
 std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
-kgrad_tensors_device.emplace_back(
-std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
-vgrad_tensors_device.emplace_back(
-std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
+kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+sizeof(OutputDataType) * kgrad_gs_ns_ks.GetElementSpaceSize()));
+vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+sizeof(OutputDataType) * vgrad_gs_os_ns.GetElementSpaceSize()));
 ygrad_tensors_device.emplace_back(
 std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
@@ -690,6 +722,7 @@ int run(int argc, char* argv[])
 QKVElementOp{},
 YElementOp{},
 p_drop, // dropout ratio
+h_ratio,
 {seed, offset}); // dropout random seed and offset, offset should
 // be at least the number of elements on a thread
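Forward and backward receive the same {seed, offset} pair so the dropout mask can be regenerated in the backward pass rather than stored. A conceptual illustration of the idea, with std::mt19937_64 standing in for the kernel's counter-based generator (this is not CK's RNG):

```cpp
#include <cstdint>
#include <random>
#include <vector>

// Rebuild the same keep/drop pattern from an identical (seed, offset) pair.
std::vector<std::uint8_t> dropout_mask(std::uint64_t seed, std::uint64_t offset,
                                       std::size_t n, float p_drop)
{
    std::mt19937_64 rng(seed);
    rng.discard(offset); // counter-based generators can skip ahead in O(1)
    std::bernoulli_distribution drop(p_drop);
    std::vector<std::uint8_t> keep(n);
    for(std::size_t i = 0; i < n; ++i)
        keep[i] = drop(rng) ? 0 : 1; // kept values are later scaled by 1 / (1 - p_drop)
    return keep;
}
// Calling dropout_mask(seed, offset, n, p) in both passes of one run yields the
// same mask, so no mask tensor has to round-trip through memory.
```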
@@ -737,6 +770,7 @@ int run(int argc, char* argv[])
 QKVElementOp{},
 YElementOp{},
 p_drop,
+h_ratio,
 std::tuple<unsigned long long, unsigned long long>(seed, offset));
 DeviceMem problem_desc_workspace_bwd(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
@@ -787,6 +821,7 @@ int run(int argc, char* argv[])
 QKVElementOp{},
 YElementOp{},
 p_drop, // dropout ratio
+h_ratio,
 {seed, offset}); // dropout random seed and offset, offset should
 // be at least the number of elements on a thread
@@ -826,6 +861,7 @@ int run(int argc, char* argv[])
 QKVElementOp{},
 YElementOp{},
 p_drop,
+h_ratio,
 std::tuple<unsigned long long, unsigned long long>(seed, offset));
 DeviceMem problem_desc_workspace_bwd_verify(gemm_bwd.GetWorkSpaceSize(&argument_bwd));
 gemm_bwd.SetWorkSpacePointer(&argument_bwd,
@@ -840,7 +876,7 @@ int run(int argc, char* argv[])
 for(std::size_t i = 0; i < group_count; i++)
 {
-int G1 = v_tensors[i].GetLengths()[1];
+int G1 = q_tensors[i].GetLengths()[1];
 // copy z matrix data from device
 z_fwd_tensors_device[i]->FromDevice(z_fwd_tensors[i].mData.data());
 z_fwd_tensors[i].ForEach([&](auto& self, auto idx) {
@@ -863,7 +899,7 @@ int run(int argc, char* argv[])
 p_dropout_in_uint8_t,
 rp_dropout);
-int G0 = v_tensors[i].GetLengths()[0];
+int G0 = q_tensors[i].GetLengths()[0];
 int O = v_tensors[i].GetLengths()[2];
 int N = v_tensors[i].GetLengths()[3];
 int M = q_tensors[i].GetLengths()[2];
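G0 and G1 now have to come from the Q descriptor: after this change v_tensors[i].GetLengths()[1] is G2, and sizing the per-query-head intermediates by the KV head count would under-allocate them. In short:

```cpp
// After this commit:
//   q_tensors[i].GetLengths() == {G0, G1, M, K}  // G1 = h_q query heads
//   v_tensors[i].GetLengths() == {G0, G2, O, N}  // G2 = h_kv KV heads, G2 <= G1
// Anything sized per query head (S, P, Z, LSE, dQ) must therefore use Q's G1.
```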
@@ -921,10 +957,10 @@ int run(int argc, char* argv[])
 Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
 q_tensors[i].GetStrides());
-Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
-k_tensors[i].GetStrides());
-Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
-v_tensors[i].GetStrides());
+Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_tensors[i].GetLengths(),
+kgrad_tensors[i].GetStrides());
+Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_tensors[i].GetLengths(),
+vgrad_tensors[i].GetStrides());
 Tensor<InputDataType> y_gs_ms_os_host_result(y_tensors[i].GetLengths(),
 y_tensors[i].GetStrides());
 Tensor<LSEDataType> lse_gs_ms_host_result(lse_tensors[i].GetLengths(),
@@ -932,10 +968,10 @@ int run(int argc, char* argv[])
 Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
 q_tensors[i].GetStrides());
-Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
-k_tensors[i].GetStrides());
-Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
-v_tensors[i].GetStrides());
+Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_tensors[i].GetLengths(),
+kgrad_tensors[i].GetStrides());
+Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_tensors[i].GetLengths(),
+vgrad_tensors[i].GetStrides());
 Tensor<InputDataType> y_gs_ms_os_device_result(y_tensors[i].GetLengths(),
 y_tensors[i].GetStrides());
 Tensor<LSEDataType> lse_gs_ms_device_result(lse_tensors[i].GetLengths(),
......