Commit 25d3a65b authored by guangzlu's avatar guangzlu
Browse files

fixed some bugs in example

parent e675b5c3
...@@ -157,6 +157,8 @@ int run(int argc, char* argv[]) ...@@ -157,6 +157,8 @@ int run(int argc, char* argv[])
<< std::endl; << std::endl;
} }
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
...@@ -185,7 +187,7 @@ int run(int argc, char* argv[]) ...@@ -185,7 +187,7 @@ int run(int argc, char* argv[])
b0_tensors.push_back(b0_gs_ns_ks); b0_tensors.push_back(b0_gs_ns_ks);
b1_tensors.push_back(b1_gs_os_ns); b1_tensors.push_back(b1_gs_os_ns);
c_tensors.push_back(c_gs_ms_os_device_result); c_tensors.push_back(c_gs_ms_os_device_result);
z_tensors.push_back(c_gs_ms_os_device_result); z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms_device_result); lse_tensors.push_back(lse_gs_ms_device_result);
a_tensors_device.emplace_back(std::make_unique<DeviceMem>( a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
...@@ -196,6 +198,8 @@ int run(int argc, char* argv[]) ...@@ -196,6 +198,8 @@ int run(int argc, char* argv[])
sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize())); sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>( c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize())); sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
lse_tensors_device.emplace_back(std::make_unique<DeviceMem>( lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize())); sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize()));
...@@ -207,7 +211,7 @@ int run(int argc, char* argv[]) ...@@ -207,7 +211,7 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer()); p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer()); p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(nullptr); p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer()); p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment