Unverified Commit 12f4cfce authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

fixed alloc mem size (#145)

parent f95267f1
...@@ -169,11 +169,11 @@ int main(int argc, char* argv[]) ...@@ -169,11 +169,11 @@ int main(int argc, char* argv[])
for(int i = 0; i < gemm_shapes.size(); i++) for(int i = 0; i < gemm_shapes.size(); i++)
{ {
a_tensors_device.emplace_back( a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace()));
b_tensors_device.emplace_back( b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize())); std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpace()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>( c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize())); sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpace()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
......
...@@ -145,12 +145,12 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -145,12 +145,12 @@ void profile_grouped_gemm_impl(int do_verification,
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
a_device_buf.emplace_back( a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSize())); std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
b_device_buf.emplace_back( b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSize())); std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>( c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSize())); sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment