Commit 3185cbf9 authored by fsx950223

add verification

parent fe6ee651
@@ -244,7 +244,7 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    float K = 64;
+    float K = 128;
     float alpha = 1.f / std::sqrt(K);
     bool input_permute = false;
@@ -297,6 +297,12 @@ int run(int argc, char* argv[])
     std::vector<DataType*> p_vgrad;
     std::vector<const DataType*> p_ygrad;
+    std::vector<Tensor<DataType>> q_g_m_ks;
+    std::vector<Tensor<DataType>> k_g_n_ks;
+    std::vector<Tensor<DataType>> v_g_n_os;
+    std::vector<Tensor<AccDataType>> s_g_m_ns;
+    std::vector<Tensor<DataType>> p_g_m_ns;
+    std::vector<Tensor<DataType>> y_g_m_os;
     std::vector<Tensor<DataType>> q_tensors;
     std::vector<Tensor<DataType>> k_tensors;
     std::vector<Tensor<DataType>> v_tensors;
@@ -478,14 +484,20 @@ int run(int argc, char* argv[])
            [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
        lse_gs_ms.ForEach(
            [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
        run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
        y_gs_ms_os.ForEach(
            [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
        lse_gs_ms.ForEach(
            [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
+       q_g_m_ks.push_back(q_g_m_k);
+       k_g_n_ks.push_back(k_g_n_k);
+       v_g_n_os.push_back(v_g_n_o);
+       s_g_m_ns.push_back(s_g_m_n);
+       p_g_m_ns.push_back(p_g_m_n);
+       y_g_m_os.push_back(y_g_m_o);
        q_tensors.push_back(q_gs_ms_ks);
        k_tensors.push_back(k_gs_ns_ks);
        v_tensors.push_back(v_gs_os_ns);
@@ -566,172 +578,119 @@ int run(int argc, char* argv[])
    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

-   // bool pass = true;
-   // if(do_verification)
-   // {
-   // kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
-   // vgrad_device_buf.SetZero();
-   // invoker.Run(argument, StreamConfig{nullptr, false});
-   // Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
-   // Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
-   // Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
-   // Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
-   // Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
-   // Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
-   // Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
-   // ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
-   //     ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
-   // });
-   // #if PRINT_HOST
-   // {
-   //     std::cout << "q_g_m_k ref:\n" << q_g_m_k;
-   //     std::cout << "k_g_n_k ref:\n" << k_g_n_k;
-   //     std::cout << "v_g_n_o ref:\n" << v_g_n_o;
-   //     std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-   // }
-   // #endif
-   // // Gradients
-   // auto ref_gemm_grad = ReferenceGemmGradInstance{};
-   // auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
-   // using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
-   // // dP = dY * V^T
-   // auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
-   // ref_gemm_grad_invoker.Run(RefGemmGradArg{
-   //     ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
-   // #if PRINT_HOST
-   // {
-   //     std::cout << "===== dP = dY * V^T\n";
-   //     std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-   //     std::cout << "v_g_o_n ref:\n" << v_g_o_n;
-   //     std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
-   // }
-   // #endif
-   // // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
-   // sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
-   //     float ygrad_dot_y = 0;
-   //     for(int o = 0; o < O; o++)
-   //     {
-   //         auto idx_gmo = idx_gmn;
-   //         idx_gmo[2] = o;
-   //         ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
-   //     }
-   //     self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
-   // });
-   // #if PRINT_HOST
-   // {
-   //     std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
-   //     std::cout << "p_g_m_n ref:\n" << p_g_m_n;
-   //     std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
-   //     std::cout << "y_g_m_o ref:\n" << y_g_m_o;
-   //     std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-   //     std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
-   // }
-   // #endif
-   // // dV = P^T * dY
-   // auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
-   // ref_gemm_grad_invoker.Run(RefGemmGradArg{
-   //     p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
-   // #if PRINT_HOST
-   // {
-   //     std::cout << "===== dV = P^T * dY\n";
-   //     std::cout << "p_g_n_m ref:\n" << p_g_n_m;
-   //     std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
-   //     std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
-   // }
-   // #endif
-   // // dQ = alpha * dS * K
-   // ref_gemm_grad_invoker.Run(RefGemmGradArg{
-   //     sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
-   // #if PRINT_HOST
-   // {
-   //     std::cout << "===== dQ = alpha * dS * K\n";
-   //     std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
-   //     std::cout << "k_g_n_k ref:\n" << k_g_n_k;
-   //     std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
-   // }
-   // #endif
-   // // dK = alpha * dS^T * Q
-   // auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
-   // ref_gemm_grad_invoker.Run(RefGemmGradArg{
-   //     sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
-   // #if PRINT_HOST
-   // {
-   //     std::cout << "===== dK = alpha * dS^T * Q\n";
-   //     std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
-   //     std::cout << "q_g_m_k ref:\n" << q_g_m_k;
-   //     std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
-   // }
-   // #endif
-   // Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-   // Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-   // Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-   // Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-   // Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-   // Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-   // qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
-   // kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
-   // vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
-   // // permute
-   // qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
-   //     const size_t& g0 = idx[0];
-   //     const size_t& g1 = idx[1];
-   //     const size_t g = g0 * G1 + g1;
-   //     self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
-   // });
-   // kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
-   //     const size_t& g0 = idx[0];
-   //     const size_t& g1 = idx[1];
-   //     const size_t g = g0 * G1 + g1;
-   //     self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
-   // });
-   // vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
-   //     const size_t& g0 = idx[0];
-   //     const size_t& g1 = idx[1];
-   //     const size_t g = g0 * G1 + g1;
-   //     self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
-   // });
-   // std::cout << "Checking qgrad:\n";
-   // pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
-   //                              qgrad_gs_ms_ks_host_result.mData,
-   //                              "error",
-   //                              1e-2,
-   //                              1e-2);
-   // std::cout << "Checking kgrad:\n";
-   // pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
-   //                              kgrad_gs_ns_ks_host_result.mData,
-   //                              "error",
-   //                              1e-2,
-   //                              1e-2);
-   // std::cout << "Checking vgrad:\n";
-   // pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
-   //                              vgrad_gs_os_ns_host_result.mData,
-   //                              "error",
-   //                              1e-2,
-   //                              1e-2);
-   // }
-   // return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
-   return 0;
+   bool pass = true;
+   if(do_verification)
+   {
+       for(int i = 0; i < group_count; i++)
+       {
+           qgrad_tensors_device[i]->SetZero();
+           kgrad_tensors_device[i]->SetZero();
+           vgrad_tensors_device[i]->SetZero();
+       }
+       invoker.Run(argument, StreamConfig{nullptr, false});
+       for(std::size_t i = 0; i < group_count; i++)
+       {
+           int G0 = v_tensors[i].GetLengths()[0];
+           int G1 = v_tensors[i].GetLengths()[1];
+           int O  = v_tensors[i].GetLengths()[2];
+           int N  = v_tensors[i].GetLengths()[3];
+           int M  = q_tensors[i].GetLengths()[2];
+           int K  = q_tensors[i].GetLengths()[3];
+           int BatchCount = G0 * G1;
+
+           Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
+           Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
+           Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
+           Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
+           Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
+           Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
+           Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
+
+           ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
+               ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+           });
+
+           auto ref_gemm_grad         = ReferenceGemmGradInstance{};
+           auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
+           using RefGemmGradArg       = ReferenceGemmGradInstance::Argument;
+
+           // dP = dY * V^T
+           auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
+           ref_gemm_grad_invoker.Run(RefGemmGradArg{
+               ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
+           sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
+               float ygrad_dot_y = 0;
+               for(int o = 0; o < O; o++)
+               {
+                   auto idx_gmo = idx_gmn;
+                   idx_gmo[2] = o;
+                   ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_os[i](idx_gmo);
+               }
+               self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
+           });
+           auto p_g_n_m = p_g_m_ns[i].Transpose({0, 2, 1});
+           ref_gemm_grad_invoker.Run(RefGemmGradArg{
+               p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
+           ref_gemm_grad_invoker.Run(RefGemmGradArg{
+               sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
+           auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
+           ref_gemm_grad_invoker.Run(RefGemmGradArg{
+               sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
+
+           Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
+           Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
+           Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
+
+           Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
+           Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
+           Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
+
+           qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
+           kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
+           vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
+
+           // permute
+           qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
+               const size_t& g0 = idx[0];
+               const size_t& g1 = idx[1];
+
+               const size_t g = g0 * G1 + g1;
+
+               self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
+           });
+           kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
+               const size_t& g0 = idx[0];
+               const size_t& g1 = idx[1];
+
+               const size_t g = g0 * G1 + g1;
+
+               self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
+           });
+           vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
+               const size_t& g0 = idx[0];
+               const size_t& g1 = idx[1];
+
+               const size_t g = g0 * G1 + g1;
+
+               self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
+           });
+
+           std::cout << "Checking qgrad:\n";
+           pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
+                                        qgrad_gs_ms_ks_host_result.mData,
+                                        "error",
+                                        1e-2,
+                                        1e-2);
+           std::cout << "Checking kgrad:\n";
+           pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
+                                        kgrad_gs_ns_ks_host_result.mData,
+                                        "error",
+                                        1e-2,
+                                        1e-2);
+           std::cout << "Checking vgrad:\n";
+           pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
+                                        vgrad_gs_os_ns_host_result.mData,
+                                        "error",
+                                        1e-2,
+                                        1e-2);
+       }
+   }
+
+   return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
 }

 int main(int argc, char* argv[]) { return run(argc, argv); }
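The per-group host check added above reruns the kernel, copies dQ/dK/dV back from the device, and compares them against a CPU reference assembled from ReferenceGemmGradInstance GEMMs. As a reading aid, here is a minimal, self-contained sketch of the same gradient formulas (dP = dY·Vᵀ, dS = P ∘ (dP − rowwise dY·Y), dV = Pᵀ·dY, dQ = α·dS·K, dK = α·dSᵀ·Q) written as plain loops for a single head. The tiny sizes, constant initial values, row-major layout, and printout are illustrative assumptions and are not part of the commit.

```cpp
// Standalone sketch of the attention-backward math the verification checks.
// Not the CK reference path; plain fp32 loops, single batch/head, row-major.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 4, N = 4, K = 8, O = 8; // illustrative sizes only
    const float alpha = 1.f / std::sqrt(static_cast<float>(K));

    // Q[M,K], Kt[N,K], V[N,O], dY[M,O] with arbitrary fill values
    std::vector<float> Q(M * K, 0.01f), Kt(N * K, 0.02f), V(N * O, 0.03f), dY(M * O, 1.f);
    std::vector<float> P(M * N), Y(M * O), dS(M * N), dQ(M * K), dK(N * K), dV(N * O);

    // Forward: S = alpha * Q * K^T, P = softmax_row(S), Y = P * V
    for(int m = 0; m < M; ++m)
    {
        std::vector<float> S(N);
        float row_max = -1e30f, row_sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            float s = 0.f;
            for(int k = 0; k < K; ++k)
                s += Q[m * K + k] * Kt[n * K + k];
            S[n]    = alpha * s;
            row_max = std::max(row_max, S[n]);
        }
        for(int n = 0; n < N; ++n)
            row_sum += (P[m * N + n] = std::exp(S[n] - row_max));
        for(int n = 0; n < N; ++n)
            P[m * N + n] /= row_sum;
        for(int o = 0; o < O; ++o)
        {
            float y = 0.f;
            for(int n = 0; n < N; ++n)
                y += P[m * N + n] * V[n * O + o];
            Y[m * O + o] = y;
        }
    }

    // Backward: dP = dY * V^T; dS = P .* (dP - dY_row dot Y_row);
    //           dV = P^T * dY; dQ = alpha * dS * K; dK = alpha * dS^T * Q
    for(int m = 0; m < M; ++m)
    {
        float ydot = 0.f;
        for(int o = 0; o < O; ++o)
            ydot += dY[m * O + o] * Y[m * O + o];
        for(int n = 0; n < N; ++n)
        {
            float dp = 0.f;
            for(int o = 0; o < O; ++o)
                dp += dY[m * O + o] * V[n * O + o];
            dS[m * N + n] = P[m * N + n] * (dp - ydot);
        }
    }
    for(int n = 0; n < N; ++n)
        for(int o = 0; o < O; ++o)
        {
            float dv = 0.f;
            for(int m = 0; m < M; ++m)
                dv += P[m * N + n] * dY[m * O + o];
            dV[n * O + o] = dv;
        }
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
        {
            float dq = 0.f;
            for(int n = 0; n < N; ++n)
                dq += dS[m * N + n] * Kt[n * K + k];
            dQ[m * K + k] = alpha * dq;
        }
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
        {
            float dk = 0.f;
            for(int m = 0; m < M; ++m)
                dk += dS[m * N + n] * Q[m * K + k];
            dK[n * K + k] = alpha * dk;
        }

    std::printf("dQ[0]=%f dK[0]=%f dV[0]=%f\n", dQ[0], dK[0], dV[0]);
    return 0;
}
```

Either form of the row correction term works here, since dY·Y for a row equals the rowwise sum of P ∘ dP; the example code uses the dY·Y form via `ygrad_dot_y`.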
@@ -50,7 +50,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     const index_t block_id = get_block_1d_id();
     const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
         cast_pointer_to_generic_address_space(group_kernel_args));
     index_t left = 0;
@@ -718,9 +718,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
            const auto p_c_grid     = static_cast<const DataType*>(p_Cs[i]);
            const auto p_lse_grid   = static_cast<const LSEDataType*>(p_LSEs[i]);
            const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]);
            auto p_qgrad_grid = static_cast<DataType*>(p_Qgrads[i]);
            auto p_kgrad_grid = static_cast<DataType*>(p_Kgrads[i]);
            auto p_vgrad_grid = static_cast<DataType*>(p_Vgrads[i]);
            const auto& problem_desc = problem_desc_vec[i];
@@ -844,31 +844,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
            // ignore = acc1_biases_gs_ms_gemm1ns_strides;
        }

-       // void Print() const
-       // {
-       //     std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
-       //               << a_grid_desc_g_m_k_.GetLength(I1) << ", "
-       //               << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
-       //     // a_grid_desc_g_m_k_.Print();
-       //     std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
-       //               << b_grid_desc_g_n_k_.GetLength(I1) << ", "
-       //               << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
-       //     // b_grid_desc_g_n_k_.Print();
-       //     std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
-       //               << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
-       //               << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
-       //     // b1_grid_desc_g_n_k_.Print();
-       //     std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
-       //               << c_grid_desc_g_m_n_.GetLength(I1) << ", "
-       //               << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
-       //     // c_grid_desc_g_m_n_.Print();
-       //     std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
-       //               << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
-       //     std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
-       //               << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
-       //               << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
-       // }
-
        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
@@ -914,15 +889,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
        float ave_time = 0;

        auto launch_kernel = [&](auto has_main_k_block_loop_) {
-           const auto kernel = kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<
-               GridwiseGemm,
-               GroupKernelArg,
-               AElementwiseOperation,
-               BElementwiseOperation,
-               AccElementwiseOperation,
-               B1ElementwiseOperation,
-               CElementwiseOperation,
-               has_main_k_block_loop_>;
+           const auto kernel =
+               kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
+                                                                GroupKernelArg,
+                                                                AElementwiseOperation,
+                                                                BElementwiseOperation,
+                                                                AccElementwiseOperation,
+                                                                B1ElementwiseOperation,
+                                                                CElementwiseOperation,
+                                                                has_main_k_block_loop_>;

            return launch_and_time_kernel(
                stream_config,
......