"vscode:/vscode.git/clone" did not exist on "89ed418b8366066a8d54d4c2ae98454919fbc207"
Commit 65c56e56 authored by Chao Liu's avatar Chao Liu
Browse files

update Tensor

parent 028171e9
...@@ -114,9 +114,9 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -114,9 +114,9 @@ bool profile_batched_gemm_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data()); a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data());
......
...@@ -193,13 +193,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -193,13 +193,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
} }
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
d0_g_m_device_result.mDesc.GetElementSpace()); d0_g_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpace()); d1_g_m_device_result.mDesc.GetElementSpaceSize());
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(), std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()}; reduce1_device_buf.GetDeviceBuffer()};
......
...@@ -93,9 +93,9 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -93,9 +93,9 @@ bool profile_conv_bwd_data_impl(int do_verification,
weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpaceSize());
out_device_buf.ToDevice(output.mData.data()); out_device_buf.ToDevice(output.mData.data());
wei_device_buf.ToDevice(weight.mData.data()); wei_device_buf.ToDevice(weight.mData.data());
......
...@@ -99,9 +99,10 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -99,9 +99,10 @@ bool profile_conv_bwd_weight_impl(int do_verification,
output.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5}); output.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
} }
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight_device_result.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) *
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); weight_device_result.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(input.mData.data()); in_device_buf.ToDevice(input.mData.data());
out_device_buf.ToDevice(output.mData.data()); out_device_buf.ToDevice(output.mData.data());
......
...@@ -157,12 +157,12 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, ...@@ -157,12 +157,12 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); out_n_k_ho_wo_device_result.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpaceSize());
DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpace()); DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
......
...@@ -149,11 +149,11 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -149,11 +149,11 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); out_n_k_ho_wo_device_result.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
......
...@@ -71,9 +71,9 @@ bool profile_conv_fwd_impl(int do_verification, ...@@ -71,9 +71,9 @@ bool profile_conv_fwd_impl(int do_verification,
const auto wei_element_op = WeiElementOp{}; const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{}; const auto out_element_op = OutElementOp{};
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(input.mData.data()); in_device_buf.ToDevice(input.mData.data());
wei_device_buf.ToDevice(weight.mData.data()); wei_device_buf.ToDevice(weight.mData.data());
......
...@@ -146,11 +146,11 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -146,11 +146,11 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
} }
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace()); DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace()); DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -217,15 +217,15 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -217,15 +217,15 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
} }
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace()); DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace()); reduce0_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace()); reduce1_m_device_result.mDesc.GetElementSpaceSize());
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(), std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()}; reduce1_device_buf.GetDeviceBuffer()};
......
...@@ -142,10 +142,10 @@ bool profile_gemm_bilinear_impl(int do_verification, ...@@ -142,10 +142,10 @@ bool profile_gemm_bilinear_impl(int do_verification,
} }
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace()); DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -86,9 +86,9 @@ int profile_gemm_impl(int do_verification, ...@@ -86,9 +86,9 @@ int profile_gemm_impl(int do_verification,
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{}; const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -189,13 +189,13 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -189,13 +189,13 @@ bool profile_gemm_reduce_impl(int do_verification,
} }
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) * DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace()); reduce0_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) * DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace()); reduce1_m_device_result.mDesc.GetElementSpaceSize());
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(), std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()}; reduce1_device_buf.GetDeviceBuffer()};
......
...@@ -87,9 +87,9 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -87,9 +87,9 @@ bool profile_gemm_splitk_impl(int do_verification,
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{}; const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -152,12 +152,12 @@ void profile_grouped_gemm_impl(int do_verification, ...@@ -152,12 +152,12 @@ void profile_grouped_gemm_impl(int do_verification,
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
a_device_buf.emplace_back( a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace())); std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back( b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace())); std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>( c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace())); sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
......
...@@ -92,8 +92,8 @@ void profile_normalization_impl(int do_verification, ...@@ -92,8 +92,8 @@ void profile_normalization_impl(int do_verification,
Tensor<OutDataType> out_ref(out); Tensor<OutDataType> out_ref(out);
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data()); in_dev.ToDevice(in.mData.data());
out_dev.ToDevice(out.mData.data()); out_dev.ToDevice(out.mData.data());
......
...@@ -245,13 +245,13 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -245,13 +245,13 @@ bool profile_reduce_impl_impl(bool do_verification,
} }
if(beta != 0.0f) if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
out.mData[i] = out_ref.mData[i]; out.mData[i] = out_ref.mData[i];
}; };
// these buffers are usually provided by the user application // these buffers are usually provided by the user application
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data()); in_dev.ToDevice(in.mData.data());
......
...@@ -71,9 +71,9 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, ...@@ -71,9 +71,9 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
auto invoker_ptr = gemmPtr->MakeInvokerPointer(); auto invoker_ptr = gemmPtr->MakeInvokerPointer();
auto argument_ptr = auto argument_ptr =
......
...@@ -127,9 +127,9 @@ int test_gemm(const gemmArgs& args) ...@@ -127,9 +127,9 @@ int test_gemm(const gemmArgs& args)
ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
......
...@@ -102,10 +102,10 @@ class TestLayernorm : public ::testing::Test ...@@ -102,10 +102,10 @@ class TestLayernorm : public ::testing::Test
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0}); gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0}); beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace()); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
......
...@@ -80,8 +80,8 @@ class TestSoftmax : public ::testing::Test ...@@ -80,8 +80,8 @@ class TestSoftmax : public ::testing::Test
Tensor<OutDataType> out_ref(out); Tensor<OutDataType> out_ref(out);
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data()); in_dev.ToDevice(in.mData.data());
out_dev.ToDevice(out.mData.data()); out_dev.ToDevice(out.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment