Commit 571e8728 authored by muozturk's avatar muozturk
Browse files

It used to work, but after the merge there is some problem.

parent 75cf3655
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "common_instances.hpp" #include "common_instances.hpp"
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ck/library/utility/numeric.hpp" #include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
int run_complex_contraction_bilinear_example(int argc, char* argv[]) int run_complex_contraction_scale_example(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
...@@ -159,10 +159,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -159,10 +159,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
// // Intermediate Value For E Real and Img // Intermediate Value For E Real and Img
// DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
// DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
a_device_buf_re.ToDevice(a_ms_ks_re.mData.data()); a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
b_device_buf_re.ToDevice(b_ns_ks_re.mData.data()); b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
...@@ -175,25 +174,23 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -175,25 +174,23 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
// set zero // set zero
e_device_buf_re.SetZero(); e_device_buf_re.SetZero();
e_device_buf_img.SetZero(); e_device_buf_img.SetZero();
e_device_buf_re1.SetZero();
e_device_buf_img1.SetZero();
// // set zero for intermediate values
// e_device_buf_re1.SetZero();
// e_device_buf_img1.SetZero();
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op_scale = CDEElementOp_Scale{scale}; auto cde_element_op_scale = CDEElementOp_Scale{scale};
// device operation // device operation
// C_real = A_real * B_real // E1_real = A_real * B_real
auto op = DeviceOpInstance{}; auto op_scale = DeviceOpInstance_Scale{};
auto invoker = op.MakeInvoker(); auto invoker_scale = op_scale.MakeInvoker();
auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), auto argument_re1 = op_scale.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_re.GetDeviceBuffer(), b_device_buf_re.GetDeviceBuffer(),
// std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
std::array<const void*, 0>{}, std::array<const void*, 0>{},
e_device_buf_re.GetDeviceBuffer(), e_device_buf_re1.GetDeviceBuffer(),
a_ms_ks_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
b_ns_ks_lengths, b_ns_ks_lengths,
...@@ -206,29 +203,32 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -206,29 +203,32 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
b_element_op, b_element_op,
cde_element_op_scale); cde_element_op_scale);
if(!op.IsSupportedArgument(argument_re1)) if(!op_scale.IsSupportedArgument(argument_re1))
{ {
std::cout << op.GetTypeString() << " does not support this problem" << std::endl; std::cout << op_scale.GetTypeString() << " does not support this problem" << std::endl;
return 0; return 0;
} }
float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel}); float ave_time_re1 = invoker_scale.Run(argument_re1, StreamConfig{nullptr, time_kernel});
alpha = -1.f * scale; alpha = -1.f * scale;
beta = 1.f; beta = 1.f;
a_element_op = AElementOp{};
b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta}; auto cde_element_op = CDEElementOp{alpha, beta};
// device operation
// For real Intermediate Value re_2
// For real Intermediate Value
// E_real = E1_real + A_img * B_img
auto op = DeviceOpInstance{} ;
auto invoker = op.MakeInvoker();
auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(), b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_re.GetDeviceBuffer()}, std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
e_device_buf_re.GetDeviceBuffer(), e_device_buf_re.GetDeviceBuffer(),
a_ms_ks_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
...@@ -252,15 +252,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -252,15 +252,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel}); float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});
// scale = 1.f ; auto argument_img1 = op_scale.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
// a_element_op = AElementOp{};
// b_element_op = BElementOp{};
// cde_element_op = CDEElementOp{alpha, beta};
auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(), b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 0>{}}, std::array<const void*, 0>{},
e_device_buf_img.GetDeviceBuffer(), e_device_buf_img.GetDeviceBuffer(),
a_ms_ks_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
...@@ -275,18 +269,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -275,18 +269,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
cde_element_op_scale); cde_element_op_scale);
if(!op.IsSupportedArgument(argument_img1)) if(!op_scale.IsSupportedArgument(argument_img1))
{ {
std::cout << op.GetTypeString() << " does not support this problem" << std::endl; std::cout << op_scale.GetTypeString() << " does not support this problem" << std::endl;
return 0; return 0;
} }
float ave_time_img1 = invoker.Run(argument_img1, StreamConfig{nullptr, time_kernel}); float ave_time_img1 = invoker_scale.Run(argument_img1, StreamConfig{nullptr, time_kernel});
alpha = 1.f * scale; alpha = 1.f * scale;
beta = 1.f; beta = 1.f;
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_re.GetDeviceBuffer(), b_device_buf_re.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_img.GetDeviceBuffer()}, std::array<const void*, 1>{e_device_buf_img.GetDeviceBuffer()},
...@@ -325,8 +321,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -325,8 +321,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * M * N * K * 2; std::size_t flop = std::size_t(2) * M * N * K * 2;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N) * 2;
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ; float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment