"docs/source/en/conceptual/stable_diffusion.mdx" did not exist on "12fd0736dcc51f77c52130ab10177d0c1d5a29d9"
Commit 65c56e56 authored by Chao Liu's avatar Chao Liu
Browse files

update Tensor

parent 028171e9
...@@ -188,13 +188,13 @@ int main(int argc, char* argv[]) ...@@ -188,13 +188,13 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace()); reduce0_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace()); reduce1_m_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -135,9 +135,9 @@ int run_conv_bwd_data(bool do_verification, ...@@ -135,9 +135,9 @@ int run_conv_bwd_data(bool do_verification,
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
out_device_buf.ToDevice(out.mData.data()); out_device_buf.ToDevice(out.mData.data());
wei_device_buf.ToDevice(wei.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
......
...@@ -174,13 +174,13 @@ int main(int argc, char* argv[]) ...@@ -174,13 +174,13 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
d0_g_m_device_result.mDesc.GetElementSpace()); d0_g_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpace()); d1_g_m_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data()); a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data());
......
...@@ -92,9 +92,9 @@ int main() ...@@ -92,9 +92,9 @@ int main()
a_m_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); a_m_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); b_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace()); DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpaceSize());
DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace()); DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
a_m_n_device_buf.ToDevice(a_m_n.mData.data()); a_m_n_device_buf.ToDevice(a_m_n.mData.data());
b_n_device_buf.ToDevice(b_n.mData.data()); b_n_device_buf.ToDevice(b_n.mData.data());
......
...@@ -74,9 +74,9 @@ int main() ...@@ -74,9 +74,9 @@ int main()
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpaceSize());
DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace()); DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace()); DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpaceSize());
a_m_device_buf.ToDevice(a_m.mData.data()); a_m_device_buf.ToDevice(a_m.mData.data());
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data()); b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
......
...@@ -72,9 +72,9 @@ int main() ...@@ -72,9 +72,9 @@ int main()
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); b_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpaceSize());
DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace()); DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpaceSize());
DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace()); DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpaceSize());
a_m_device_buf.ToDevice(a_m.mData.data()); a_m_device_buf.ToDevice(a_m.mData.data());
b_m_device_buf.ToDevice(b_m.mData.data()); b_m_device_buf.ToDevice(b_m.mData.data());
......
...@@ -74,9 +74,9 @@ int main() ...@@ -74,9 +74,9 @@ int main()
a.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); a.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
b.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0}); b.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data()); a_device_buf.ToDevice(a.mData.data());
b_device_buf.ToDevice(b.mData.data()); b_device_buf.ToDevice(b.mData.data());
......
...@@ -136,9 +136,9 @@ int run_conv_bwd_weight(bool do_verification, ...@@ -136,9 +136,9 @@ int run_conv_bwd_weight(bool do_verification,
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5}); out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device_result.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_device_result.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data()); in_device_buf.ToDevice(in.mData.data());
out_device_buf.ToDevice(out.mData.data()); out_device_buf.ToDevice(out.mData.data());
......
...@@ -281,18 +281,19 @@ int main() ...@@ -281,18 +281,19 @@ int main()
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1}); gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1}); beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace()); DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace()); DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace()); DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) *
reduceMean_m.mDesc.GetElementSpaceSize());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace()); reduceMeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
layerNorm_m_n.mDesc.GetElementSpace()); layerNorm_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -249,16 +249,17 @@ int main() ...@@ -249,16 +249,17 @@ int main()
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1}); gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1}); beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace()); DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) *
reduceMean_m.mDesc.GetElementSpaceSize());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace()); reduceMeanSquare_m.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
layerNorm_m_n.mDesc.GetElementSpace()); layerNorm_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -185,13 +185,13 @@ int main(int argc, char* argv[]) ...@@ -185,13 +185,13 @@ int main(int argc, char* argv[])
c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0}); c_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<CDataType>{0});
acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}); acc_m_n_host_result.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpace()); DeviceMem c0_bias_buf(sizeof(C0DataType) * c0_n_bias.mDesc.GetElementSpaceSize());
DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpace()); DeviceMem c0_add_buf(sizeof(C0DataType) * c0_m_n_add.mDesc.GetElementSpaceSize());
DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpace()); DeviceMem c0_gamma_buf(sizeof(C0DataType) * c0_n_gamma.mDesc.GetElementSpaceSize());
DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpace()); DeviceMem c0_beta_buf(sizeof(C0DataType) * c0_n_beta.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -177,14 +177,14 @@ int main(int argc, char* argv[]) ...@@ -177,14 +177,14 @@ int main(int argc, char* argv[])
auto cgemm = DeviceCGemmInstance{}; auto cgemm = DeviceCGemmInstance{};
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace()); DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpaceSize());
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace()); DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace()); DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpace()); DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * DeviceMem c_m_n_real_device_buf(sizeof(CDataType) *
c_m_n_real_device_result.mDesc.GetElementSpace()); c_m_n_real_device_result.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
c_m_n_imag_device_result.mDesc.GetElementSpace()); c_m_n_imag_device_result.mDesc.GetElementSpaceSize());
DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC)); DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC));
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
......
...@@ -177,7 +177,7 @@ int main(int argc, char* argv[]) ...@@ -177,7 +177,7 @@ int main(int argc, char* argv[])
} }
if(beta != 0.0f) if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
out.mData[i] = out_ref.mData[i]; out.mData[i] = out_ref.mData[i];
}; };
// std::cout << "beta = " << beta << std::endl; // std::cout << "beta = " << beta << std::endl;
...@@ -185,8 +185,8 @@ int main(int argc, char* argv[]) ...@@ -185,8 +185,8 @@ int main(int argc, char* argv[])
// LogRangeAsType<float>(std::cout << "tensor prior out: " , out.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "tensor prior out: " , out.mData, ",") << std::endl;
// these buffers are usually provided by the user application // these buffers are usually provided by the user application
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data()); in_dev.ToDevice(in.mData.data());
......
...@@ -154,9 +154,10 @@ int main(int argc, char* argv[]) ...@@ -154,9 +154,10 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) *
c_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data()); a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data());
......
...@@ -186,12 +186,12 @@ int main(int argc, char* argv[]) ...@@ -186,12 +186,12 @@ int main(int argc, char* argv[])
d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0}); d_m0_m1_m2_n0_n1.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) * DeviceMem d_m0_m1_m2_n0_n1_device_buf(sizeof(DDataType) *
d_m0_m1_m2_n0_n1.mDesc.GetElementSpace()); d_m0_m1_m2_n0_n1.mDesc.GetElementSpaceSize());
DeviceMem e_m0_m1_m2_n0_n1_device_buf(sizeof(EDataType) * DeviceMem e_m0_m1_m2_n0_n1_device_buf(
e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpace()); sizeof(EDataType) * e_m0_m1_m2_n0_n1_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -324,10 +324,10 @@ int main(int argc, char* argv[]) ...@@ -324,10 +324,10 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpace()); DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace()); DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_ms_ks.mData.data()); a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data()); b_device_buf.ToDevice(b_ns_ks.mData.data());
......
...@@ -307,9 +307,9 @@ int main(int argc, char* argv[]) ...@@ -307,9 +307,9 @@ int main(int argc, char* argv[])
break; break;
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpace()); DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_ms_ks.mData.data()); a_device_buf.ToDevice(a_ms_ks.mData.data());
b_device_buf.ToDevice(b_ns_ks.mData.data()); b_device_buf.ToDevice(b_ns_ks.mData.data());
......
...@@ -75,10 +75,10 @@ int main() ...@@ -75,10 +75,10 @@ int main()
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0}); gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0}); beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace()); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
......
...@@ -91,7 +91,7 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -91,7 +91,7 @@ struct ReferenceConvFwd : public device::BaseOperator
arg.weight_.GetNumOfDimension() == NDimSpatial + 3 && arg.weight_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 3)) arg.output_.GetNumOfDimension() == NDimSpatial + 3))
{ {
std::throw("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
......
...@@ -103,8 +103,8 @@ class OpInstanceRunEngine ...@@ -103,8 +103,8 @@ class OpInstanceRunEngine
} }
} }
AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{}); AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
out_device_buffer_ = out_device_buffer_ = std::make_unique<DeviceMem>(sizeof(OutDataType) *
std::make_unique<DeviceMem>(sizeof(OutDataType) * out_tensor_->mDesc.GetElementSpace()); out_tensor_->mDesc.GetElementSpaceSize());
out_device_buffer_->SetZero(); out_device_buffer_->SetZero();
} }
...@@ -222,7 +222,7 @@ class OpInstanceRunEngine ...@@ -222,7 +222,7 @@ class OpInstanceRunEngine
in_device_buffers_ in_device_buffers_
.emplace_back( .emplace_back(
std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<Index, InArgsTypesTuple>) * std::make_unique<DeviceMem>(sizeof(std::tuple_element_t<Index, InArgsTypesTuple>) *
ts->mDesc.GetElementSpace())) ts->mDesc.GetElementSpaceSize()))
->ToDevice(ts->mData.data()); ->ToDevice(ts->mData.data());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment