Commit 15146ed6 authored by muozturk's avatar muozturk
Browse files

try to fix

parent 1e5e2dc3
...@@ -110,9 +110,9 @@ int run_complex_contraction_scale_example(int argc, char* argv[]) ...@@ -110,9 +110,9 @@ int run_complex_contraction_scale_example(int argc, char* argv[])
Tensor<EDataType> e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); Tensor<EDataType> e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides); Tensor<EDataType> e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides);
// // Intermediate E tensor Definition // Intermediate E tensor Definition
// Tensor<EDataType> e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides); Tensor<EDataType> e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides);
// Tensor<EDataType> e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides); Tensor<EDataType> e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl; std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl;
std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl; std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl;
...@@ -150,25 +150,15 @@ int run_complex_contraction_scale_example(int argc, char* argv[]) ...@@ -150,25 +150,15 @@ int run_complex_contraction_scale_example(int argc, char* argv[])
DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_re(sizeof(EDataType) * DeviceMem e_device_buf_re(sizeof(EDataType) *e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_img(sizeof(EDataType) *e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
<<<<<<< HEAD
// Intermediate Value For E Real and Img // Intermediate Value For E Real and Img
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
=======
DeviceMem e_device_buf_img(sizeof(EDataType) *
e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
// // Intermediate Value For E Real and Img
// DeviceMem e_device_buf_re1(sizeof(EDataType) *
// e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); DeviceMem
// e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
a_device_buf_re.ToDevice(a_ms_ks_re.mData.data()); a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
b_device_buf_re.ToDevice(b_ns_ks_re.mData.data()); b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
...@@ -183,23 +173,13 @@ int run_complex_contraction_scale_example(int argc, char* argv[]) ...@@ -183,23 +173,13 @@ int run_complex_contraction_scale_example(int argc, char* argv[])
e_device_buf_img1.SetZero(); e_device_buf_img1.SetZero();
<<<<<<< HEAD
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
=======
// // set zero for intermediate values
// e_device_buf_re1.SetZero();
// e_device_buf_img1.SetZero();
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
auto cde_element_op_scale = CDEElementOp_Scale{scale}; auto cde_element_op_scale = CDEElementOp_Scale{scale};
// device operation // device operation
// E1_real = A_real * B_real // E1_real1 = A_real * B_real
<<<<<<< HEAD
auto op_scale = DeviceOpInstance_Scale{}; auto op_scale = DeviceOpInstance_Scale{};
auto invoker_scale = op_scale.MakeInvoker(); auto invoker_scale = op_scale.MakeInvoker();
auto argument_re1 = op_scale.MakeArgument(a_device_buf_re.GetDeviceBuffer(), auto argument_re1 = op_scale.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
...@@ -217,32 +197,10 @@ int run_complex_contraction_scale_example(int argc, char* argv[]) ...@@ -217,32 +197,10 @@ int run_complex_contraction_scale_example(int argc, char* argv[])
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op_scale); cde_element_op_scale);
=======
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument_re1 =
op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_re.GetDeviceBuffer(),
// std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
std::array<const void*, 0>{},
e_device_buf_re.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 0>{},
std::array<std::vector<ck::index_t>, 0>{},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op_scale);
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
if(!op_scale.IsSupportedArgument(argument_re1)) if(!op_scale.IsSupportedArgument(argument_re1))
{ {
std::cout << op_scale.GetTypeString() << " does not support this problem" << std::endl; std::cout << op_scale.GetTypeString() << " does not support this problem" << std::endl;
return 0; return 0;
} }
...@@ -251,13 +209,12 @@ int run_complex_contraction_scale_example(int argc, char* argv[]) ...@@ -251,13 +209,12 @@ int run_complex_contraction_scale_example(int argc, char* argv[])
alpha = -1.f * scale; alpha = -1.f * scale;
beta = 1.f; beta = 1.f;
<<<<<<< HEAD
auto cde_element_op = CDEElementOp{alpha, beta}; auto cde_element_op = CDEElementOp{alpha, beta};
// For real Intermediate Value // For real Intermediate Value
// E_real = E1_real + A_img * B_img // E_real = beta * E1_real + alpha * A_img * B_img
auto op = DeviceOpInstance{} ; auto op = DeviceOpInstance{} ;
...@@ -277,59 +234,22 @@ int run_complex_contraction_scale_example(int argc, char* argv[]) ...@@ -277,59 +234,22 @@ int run_complex_contraction_scale_example(int argc, char* argv[])
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
=======
a_element_op = AElementOp{};
b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// device operation
// For real Intermediate Value re_2
auto argument_re2 =
op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_re.GetDeviceBuffer()},
e_device_buf_re.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
if(!op.IsSupportedArgument(argument_re2)) if(!op.IsSupportedArgument(argument_re2))
{ {
std::cout << op.GetTypeString() << " does not support this problem" << std::endl; std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0; return 0;
} }
float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel}); float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});
<<<<<<< HEAD // For real Intermediate Value
// E_img1 = A_re * B_img ( SCALE )
auto argument_img1 = op_scale.MakeArgument(a_device_buf_re.GetDeviceBuffer(), auto argument_img1 = op_scale.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(), b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 0>{}, std::array<const void*, 0>{},
======= e_device_buf_img1.GetDeviceBuffer(),
// scale = 1.f ;
// a_element_op = AElementOp{};
// b_element_op = BElementOp{};
// cde_element_op = CDEElementOp{alpha, beta};
auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 0>{}
},
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
e_device_buf_img.GetDeviceBuffer(),
a_ms_ks_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
b_ns_ks_lengths, b_ns_ks_lengths,
...@@ -342,25 +262,20 @@ int run_complex_contraction_scale_example(int argc, char* argv[]) ...@@ -342,25 +262,20 @@ int run_complex_contraction_scale_example(int argc, char* argv[])
b_element_op, b_element_op,
cde_element_op_scale); cde_element_op_scale);
if(!op.IsSupportedArgument(argument_img1))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
<<<<<<< HEAD
if(!op_scale.IsSupportedArgument(argument_img1)) if(!op_scale.IsSupportedArgument(argument_img1))
{ {
std::cout << op_scale.GetTypeString() << " does not support this problem" << std::endl; std::cout << op_scale.GetTypeString() << " does not support this problem" << std::endl;
=======
return 0; return 0;
}
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
float ave_time_img1 = invoker.Run(argument_img1, StreamConfig{nullptr, time_kernel}); }
alpha = 1.f * scale; float ave_time_img1 = invoker_scale.Run(argument_img1, StreamConfig{nullptr, time_kernel});
beta = 1.f;
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), alpha = 1.f * scale;
beta = 1.f;
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_re.GetDeviceBuffer(), b_device_buf_re.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_img.GetDeviceBuffer()}, std::array<const void*, 1>{e_device_buf_img.GetDeviceBuffer()},
e_device_buf_img.GetDeviceBuffer(), e_device_buf_img.GetDeviceBuffer(),
...@@ -376,78 +291,46 @@ auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), ...@@ -376,78 +291,46 @@ auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_element_op, b_element_op,
cde_element_op); cde_element_op);
if(!op.IsSupportedArgument(argument_img2)) if(!op.IsSupportedArgument(argument_img2))
{ {
std::cout << op.GetTypeString() << " does not support this problem" << std::endl; std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
<<<<<<< HEAD
float ave_time_img1 = invoker_scale.Run(argument_img1, StreamConfig{nullptr, time_kernel});
=======
return 0; return 0;
}
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel}); }
<<<<<<< HEAD
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});
b_device_buf_re.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_img.GetDeviceBuffer()},
e_device_buf_img.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
=======
ck::index_t M =
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
ck::index_t N = ck::accumulate_n<ck::index_t>( ck::index_t M =ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>( ck::index_t N = ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * M * N * K * 2; ck::index_t K = ck::accumulate_n<ck::index_t>(a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1; std::size_t flop = std::size_t(2) * M * N * K * 2;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N )* 2;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
<< op.GetTypeString() << std::endl; float gb_per_sec = num_btype / 1.E6 / ave_time;
e_device_buf_re.FromDevice(e_ms_ns_device_result_re.mData.data()); std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "<< op.GetTypeString() << std::endl;
e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());
auto isRealOk = 0; e_device_buf_re.FromDevice(e_ms_ns_device_result_re.mData.data());
auto isImgOk = 0; e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());
if(do_verification) auto isRealOk = 0;
{ auto isImgOk = 0;
if(do_verification)
{
// Real Part Verification // Real Part Verification
Tensor<CShuffleDataType> c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<CShuffleDataType> c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides);
<<<<<<< HEAD using ReferenceOpInstance =ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
std::size_t flop = std::size_t(2) * M * N * K * 2;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N) * 2;
=======
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
ADataType, ADataType,
...@@ -457,13 +340,11 @@ if(do_verification) ...@@ -457,13 +340,11 @@ if(do_verification)
F32, F32,
AElementOp, AElementOp,
BElementOp>; BElementOp>;
>>>>>>> 81eece66cb698622dafbc0f2c726ba743eb345fe
auto ref_op = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_op.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument_re = ref_op.MakeArgument( auto ref_argument_re = ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_re); ref_invoker.Run(ref_argument_re);
...@@ -571,7 +452,7 @@ if(do_verification) ...@@ -571,7 +452,7 @@ if(do_verification)
isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1; isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
return (isRealOk && isImgOk); return (isRealOk && isImgOk);
} }
return 0; return 0;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment