Commit 65c56e56 authored by Chao Liu's avatar Chao Liu
Browse files

update Tensor

parent 028171e9
......@@ -142,9 +142,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -141,9 +141,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -139,9 +139,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -170,9 +170,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -155,9 +155,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -165,9 +165,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -167,9 +167,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -213,10 +213,10 @@ int main(int argc, char* argv[])
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -191,10 +191,10 @@ int main(int argc, char* argv[])
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -156,11 +156,11 @@ int main(int argc, char* argv[])
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -235,10 +235,11 @@ int main(int argc, char* argv[])
bias_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k_device.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_ho_wo_k_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data());
......
......@@ -138,10 +138,10 @@ int run_conv_fwd_bias(bool do_verification,
bias.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
......
......@@ -238,11 +238,12 @@ int main(int argc, char* argv[])
resi_n_ho_wo_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace());
DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_ho_wo_k.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_ho_wo_k_device.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in_n_hi_wi_c.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_y_x_c.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpaceSize());
DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_ho_wo_k.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_ho_wo_k_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_device_buf.ToDevice(wei_k_y_x_c.mData.data());
......
......@@ -146,10 +146,10 @@ int run_conv_fwd(bool do_verification,
bias.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
......
......@@ -230,13 +230,13 @@ int main(int argc, char* argv[])
}
if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
out.mData[i] = out_ref.mData[i];
};
// these buffers are usually provided by the user application
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data());
......
......@@ -171,13 +171,13 @@ int main(int argc, char* argv[])
}
if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
out.mData[i] = out_ref.mData[i];
};
DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpace());
DeviceMem in_2_dev(sizeof(InOutDataType) * in_2.mDesc.GetElementSpace());
DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpace());
DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpaceSize());
DeviceMem in_2_dev(sizeof(InOutDataType) * in_2.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpaceSize());
in_1_dev.ToDevice(in_1.mData.data());
......
......@@ -204,10 +204,11 @@ bool pool_test(bool do_verification,
default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_c_ho_wo_device.mDesc.GetElementSpaceSize());
DeviceMem out_indices_device_buf(sizeof(IndexDataType) *
out_indices_n_c_ho_wo_device.mDesc.GetElementSpace());
out_indices_n_c_ho_wo_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
......
......@@ -190,9 +190,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -170,12 +170,12 @@ int main(int argc, char* argv[])
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace()));
b_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpace()));
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpace()));
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
......
......@@ -173,11 +173,11 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce_device_buf(sizeof(ReduceDataType) *
reduce_m_device_result.mDesc.GetElementSpace());
reduce_m_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment