Commit 65c56e56 authored by Chao Liu's avatar Chao Liu
Browse files

update Tensor

parent 028171e9
......@@ -188,13 +188,13 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace());
reduce0_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
reduce1_m_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -135,9 +135,9 @@ int run_conv_bwd_data(bool do_verification,
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
out_device_buf.ToDevice(out.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
......
......@@ -174,13 +174,13 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
d0_g_m_device_result.mDesc.GetElementSpace());
d0_g_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpace());
d1_g_m_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
......
......@@ -92,9 +92,9 @@ int main()
a_m_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace());
DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpaceSize());
DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
a_m_n_device_buf.ToDevice(a_m_n.mData.data());
b_n_device_buf.ToDevice(b_n.mData.data());
......
......@@ -74,9 +74,9 @@ int main()
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace());
DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace());
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpaceSize());
DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpaceSize());
a_m_device_buf.ToDevice(a_m.mData.data());
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
......
......@@ -72,9 +72,9 @@ int main()
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace());
DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace());
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpaceSize());
DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpaceSize());
DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpaceSize());
a_m_device_buf.ToDevice(a_m.mData.data());
b_m_device_buf.ToDevice(b_m.mData.data());
......
......@@ -74,9 +74,9 @@ int main()
a.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
b_device_buf.ToDevice(b.mData.data());
......
......@@ -136,9 +136,9 @@ int run_conv_bwd_weight(bool do_verification,
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device_result.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
out_device_buf.ToDevice(out.mData.data());
......
......@@ -281,18 +281,19 @@ int main()
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) *
reduceMean_m.mDesc.GetElementSpaceSize());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
reduceMeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
layerNorm_m_n.mDesc.GetElementSpace());
layerNorm_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -249,16 +249,17 @@ int main()
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) *
reduceMean_m.mDesc.GetElementSpaceSize());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
reduceMeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
layerNorm_m_n.mDesc.GetElementSpace());
layerNorm_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -185,13 +185,13 @@ int main(int argc, char* argv[])
c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});
acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpace());
DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpace());
DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpace());
DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpaceSize());
DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpaceSize());
DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpaceSize());
DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -177,14 +177,14 @@ int main(int argc, char* argv[])
auto cgemm = DeviceCGemmInstance{};
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace());
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace());
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace());
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpace());
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpaceSize());
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) *
c_m_n_real_device_result.mDesc.GetElementSpace());
c_m_n_real_device_result.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
c_m_n_imag_device_result.mDesc.GetElementSpace());
c_m_n_imag_device_result.mDesc.GetElementSpaceSize());
DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC));
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
......
......@@ -177,7 +177,7 @@ int main(int argc, char* argv[])
}
if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
out.mData[i] = out_ref.mData[i];
};
// std::cout << "beta = " << beta << std::endl;
......@@ -185,8 +185,8 @@ int main(int argc, char* argv[])
// LogRangeAsType<float>(std::cout << "tensor prior out: " , out.mData, ",") << std::endl;
// these buffers are usually provided by the user application
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data());
......
......@@ -154,9 +154,10 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
......
......@@ -186,12 +186,12 @@ int main(int argc, char* argv[])
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) *
d_m0_m1_m2_n0_n1.mDesc.GetElementSpace());
DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) *
e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace());
d_m0_m1_m2_n0_n1.mDesc.GetElementSpaceSize());
DeviceMem e_m0_m1_m2_n0_n1_device_buf(
sizeof(EDataType) * e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -324,10 +324,10 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
......
......@@ -307,9 +307,9 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data());
......
......@@ -75,10 +75,10 @@ int main()
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
......
......@@ -91,7 +91,7 @@ struct ReferenceConvFwd : public device::BaseOperator
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{
std::throw("wrong! inconsistent dimension");
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
......
......@@ -103,8 +103,8 @@ class OpInstanceRunEngine
}
}
AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
out_device_buffer_ =
std::make_unique<DeviceMem>(sizeof(OutDataType) * out_tensor_->mDesc.GetElementSpace());
out_device_buffer_ = std::make_unique<DeviceMem>(sizeof(OutDataType) *
out_tensor_->mDesc.GetElementSpaceSize());
out_device_buffer_->SetZero();
}
......@@ -222,7 +222,7 @@ class OpInstanceRunEngine
in_device_buffers_
.emplace_back(
std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<Index, InArgsTypesTuple>) *
ts->mDesc.GetElementSpace()))
ts->mDesc.GetElementSpaceSize()))
->ToDevice(ts->mData.data());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment