Commit 65c56e56 authored by Chao Liu's avatar Chao Liu
Browse files

update Tensor

parent 028171e9
......@@ -114,9 +114,9 @@ bool profile_batched_gemm_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
......
......@@ -193,13 +193,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
d0_g_m_device_result.mDesc.GetElementSpace());
d0_g_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpace());
d1_g_m_device_result.mDesc.GetElementSpaceSize());
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
......
......@@ -93,9 +93,9 @@ bool profile_conv_bwd_data_impl(int do_verification,
weight.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * input_device_result.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpaceSize());
out_device_buf.ToDevice(output.mData.data());
wei_device_buf.ToDevice(weight.mData.data());
......
......@@ -99,9 +99,10 @@ bool profile_conv_bwd_weight_impl(int do_verification,
output.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight_device_result.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) *
weight_device_result.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(input.mData.data());
out_device_buf.ToDevice(output.mData.data());
......
......@@ -157,12 +157,12 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace());
DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpace());
out_n_k_ho_wo_device_result.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpaceSize());
DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
......
......@@ -149,11 +149,11 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) *
out_n_k_ho_wo_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace());
out_n_k_ho_wo_device_result.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
......
......@@ -71,9 +71,9 @@ bool profile_conv_fwd_impl(int do_verification,
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpace());
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace());
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(input.mData.data());
wei_device_buf.ToDevice(weight.mData.data());
......
......@@ -146,11 +146,11 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -217,15 +217,15 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace());
reduce0_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
reduce1_m_device_result.mDesc.GetElementSpaceSize());
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
......
......@@ -142,10 +142,10 @@ bool profile_gemm_bilinear_impl(int do_verification,
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -86,9 +86,9 @@ int profile_gemm_impl(int do_verification,
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -189,13 +189,13 @@ bool profile_gemm_reduce_impl(int do_verification,
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
reduce0_m_device_result.mDesc.GetElementSpace());
reduce0_m_device_result.mDesc.GetElementSpaceSize());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
reduce1_m_device_result.mDesc.GetElementSpaceSize());
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
reduce1_device_buf.GetDeviceBuffer()};
......
......@@ -87,9 +87,9 @@ bool profile_gemm_splitk_impl(int do_verification,
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -152,12 +152,12 @@ void profile_grouped_gemm_impl(int do_verification,
for(std::size_t i = 0; i < group_count; i++)
{
a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace()));
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
......
......@@ -92,8 +92,8 @@ void profile_normalization_impl(int do_verification,
Tensor<OutDataType> out_ref(out);
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data());
out_dev.ToDevice(out.mData.data());
......
......@@ -245,13 +245,13 @@ bool profile_reduce_impl_impl(bool do_verification,
}
if(beta != 0.0f)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
for(size_t i = 0; i < out_ref.mDesc.GetElementSpaceSize(); i++)
out.mData[i] = out_ref.mData[i];
};
// these buffers are usually provided by the user application
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data());
......
......@@ -71,9 +71,9 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
auto invoker_ptr = gemmPtr->MakeInvokerPointer();
auto argument_ptr =
......
......@@ -127,9 +127,9 @@ int test_gemm(const gemmArgs& args)
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......
......@@ -102,10 +102,10 @@ class TestLayernorm : public ::testing::Test
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpace());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpace());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpace());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpace());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
......
......@@ -80,8 +80,8 @@ class TestSoftmax : public ::testing::Test
Tensor<OutDataType> out_ref(out);
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
in_dev.ToDevice(in.mData.data());
out_dev.ToDevice(out.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment