Commit 85fb5d15 authored by Adam Osewski
Browse files

Use proper function to get tensor size.

parent 66c70dfe
...@@ -173,18 +173,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -173,18 +173,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
a_tensors_device.emplace_back(std::make_unique<DeviceMem>( a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * problem_size.Ms[i] * problem_size.Ks[i])); a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>( b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>( c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); c_device_tensors[i].mDesc.GetElementSpaceSize() * sizeof(EDataType)));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
c_tensors_device[i]->SetZero(); c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment