Commit a2bb4757 authored by fsx950223's avatar fsx950223
Browse files

fix example

parent 5509e684
...@@ -287,15 +287,15 @@ int run(int argc, char* argv[]) ...@@ -287,15 +287,15 @@ int run(int argc, char* argv[])
std::vector<DeviceGemmInstance::ProblemDesc> problem_descs; std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
using DeviceMemPtr = std::unique_ptr<DeviceMem>; using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<const DataType*> p_q; std::vector<const void*> p_q;
std::vector<const DataType*> p_k; std::vector<const void*> p_k;
std::vector<const DataType*> p_v; std::vector<const void*> p_v;
std::vector<const DataType*> p_y; std::vector<const void*> p_y;
std::vector<const LSEDataType*> p_lse; std::vector<const void*> p_lse;
std::vector<DataType*> p_qgrad; std::vector<void*> p_qgrad;
std::vector<DataType*> p_kgrad; std::vector<void*> p_kgrad;
std::vector<DataType*> p_vgrad; std::vector<void*> p_vgrad;
std::vector<const DataType*> p_ygrad; std::vector<const void*> p_ygrad;
std::vector<Tensor<DataType>> q_g_m_ks; std::vector<Tensor<DataType>> q_g_m_ks;
std::vector<Tensor<DataType>> k_g_n_ks; std::vector<Tensor<DataType>> k_g_n_ks;
...@@ -517,15 +517,15 @@ int run(int argc, char* argv[]) ...@@ -517,15 +517,15 @@ int run(int argc, char* argv[])
kgrad_tensors_device.back()->SetZero(); kgrad_tensors_device.back()->SetZero();
vgrad_tensors_device.back()->SetZero(); vgrad_tensors_device.back()->SetZero();
ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data()); ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data());
p_q.push_back(static_cast<DataType*>(q_tensors_device.back()->GetDeviceBuffer())); p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
p_k.push_back(static_cast<DataType*>(k_tensors_device.back()->GetDeviceBuffer())); p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
p_v.push_back(static_cast<DataType*>(v_tensors_device.back()->GetDeviceBuffer())); p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
p_y.push_back(static_cast<DataType*>(y_tensors_device.back()->GetDeviceBuffer())); p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
p_lse.push_back(static_cast<LSEDataType*>(lse_tensors_device.back()->GetDeviceBuffer())); p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
p_kgrad.push_back(static_cast<DataType*>(kgrad_tensors_device.back()->GetDeviceBuffer())); p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
p_vgrad.push_back(static_cast<DataType*>(vgrad_tensors_device.back()->GetDeviceBuffer())); p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
p_ygrad.push_back(static_cast<DataType*>(ygrad_tensors_device.back()->GetDeviceBuffer())); p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
p_qgrad.push_back(static_cast<DataType*>(qgrad_tensors_device.back()->GetDeviceBuffer())); p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
} }
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
p_q, p_q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment