"vscode:/vscode.git/clone" did not exist on "6a3bcef8ee40aaf359cb7162f057b9543ce05822"
Commit 26cc4721 authored by guangzlu's avatar guangzlu
Browse files

added z in example

parent c75a3c17
...@@ -48,6 +48,7 @@ int run(int argc, char* argv[]) ...@@ -48,6 +48,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0; std::vector<const void*> p_b0;
std::vector<const void*> p_b1; std::vector<const void*> p_b1;
std::vector<void*> p_c; std::vector<void*> p_c;
std::vector<void*> p_z;
std::vector<void*> p_lse; std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o; std::vector<std::vector<int>> g0_g1_m_n_k_o;
...@@ -102,6 +103,12 @@ int run(int argc, char* argv[]) ...@@ -102,6 +103,12 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_os_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides = std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
...@@ -114,6 +121,8 @@ int run(int argc, char* argv[]) ...@@ -114,6 +121,8 @@ int run(int argc, char* argv[])
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
lse_gs_ms_strides, lse_gs_ms_strides,
{}, // acc0_biases_gs_ms_ns_lengths {}, // acc0_biases_gs_ms_ns_lengths
...@@ -125,6 +134,7 @@ int run(int argc, char* argv[]) ...@@ -125,6 +134,7 @@ int run(int argc, char* argv[])
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
...@@ -193,6 +203,7 @@ int run(int argc, char* argv[]) ...@@ -193,6 +203,7 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer()); p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
} }
...@@ -209,6 +220,7 @@ int run(int argc, char* argv[]) ...@@ -209,6 +220,7 @@ int run(int argc, char* argv[])
p_b0, p_b0,
p_b1, p_b1,
p_c, p_c,
p_z,
p_lse, p_lse,
{}, // p_acc0_biases {}, // p_acc0_biases
{}, // p_acc1_biases {}, // p_acc1_biases
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment