Commit a2bb4757 authored by fsx950223's avatar fsx950223
Browse files

fix example

parent 5509e684
......@@ -287,15 +287,15 @@ int run(int argc, char* argv[])
std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<const DataType*> p_q;
std::vector<const DataType*> p_k;
std::vector<const DataType*> p_v;
std::vector<const DataType*> p_y;
std::vector<const LSEDataType*> p_lse;
std::vector<DataType*> p_qgrad;
std::vector<DataType*> p_kgrad;
std::vector<DataType*> p_vgrad;
std::vector<const DataType*> p_ygrad;
std::vector<const void*> p_q;
std::vector<const void*> p_k;
std::vector<const void*> p_v;
std::vector<const void*> p_y;
std::vector<const void*> p_lse;
std::vector<void*> p_qgrad;
std::vector<void*> p_kgrad;
std::vector<void*> p_vgrad;
std::vector<const void*> p_ygrad;
std::vector<Tensor<DataType>> q_g_m_ks;
std::vector<Tensor<DataType>> k_g_n_ks;
......@@ -517,15 +517,15 @@ int run(int argc, char* argv[])
kgrad_tensors_device.back()->SetZero();
vgrad_tensors_device.back()->SetZero();
ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data());
p_q.push_back(static_cast<DataType*>(q_tensors_device.back()->GetDeviceBuffer()));
p_k.push_back(static_cast<DataType*>(k_tensors_device.back()->GetDeviceBuffer()));
p_v.push_back(static_cast<DataType*>(v_tensors_device.back()->GetDeviceBuffer()));
p_y.push_back(static_cast<DataType*>(y_tensors_device.back()->GetDeviceBuffer()));
p_lse.push_back(static_cast<LSEDataType*>(lse_tensors_device.back()->GetDeviceBuffer()));
p_kgrad.push_back(static_cast<DataType*>(kgrad_tensors_device.back()->GetDeviceBuffer()));
p_vgrad.push_back(static_cast<DataType*>(vgrad_tensors_device.back()->GetDeviceBuffer()));
p_ygrad.push_back(static_cast<DataType*>(ygrad_tensors_device.back()->GetDeviceBuffer()));
p_qgrad.push_back(static_cast<DataType*>(qgrad_tensors_device.back()->GetDeviceBuffer()));
p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
}
auto argument = gemm.MakeArgument(
p_q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment