Commit 571e8728 authored by muozturk's avatar muozturk
Browse files

it was used to work but after merge there is some problem

parent 75cf3655
......@@ -3,6 +3,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "common_instances.hpp"
......
......@@ -15,7 +15,7 @@
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
int run_complex_contraction_bilinear_example(int argc, char* argv[])
int run_complex_contraction_scale_example(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
......@@ -159,10 +159,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
// // Intermediate Value For E Real and Img
// DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
// DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
// Intermediate Value For E Real and Img
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
......@@ -175,25 +174,23 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
// set zero
e_device_buf_re.SetZero();
e_device_buf_img.SetZero();
e_device_buf_re1.SetZero();
e_device_buf_img1.SetZero();
// // set zero for intermediate values
// e_device_buf_re1.SetZero();
// e_device_buf_img1.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op_scale = CDEElementOp_Scale{scale};
// device operation
// C_real = A_real * B_real
// E1_real = A_real * B_real
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
auto op_scale = DeviceOpInstance_Scale{};
auto invoker_scale = op_scale.MakeInvoker();
auto argument_re1 = op_scale.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_re.GetDeviceBuffer(),
// std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
std::array<const void*, 0>{},
e_device_buf_re.GetDeviceBuffer(),
e_device_buf_re1.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
......@@ -206,29 +203,32 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
b_element_op,
cde_element_op_scale);
if(!op.IsSupportedArgument(argument_re1))
if(!op_scale.IsSupportedArgument(argument_re1))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
std::cout << op_scale.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});
float ave_time_re1 = invoker_scale.Run(argument_re1, StreamConfig{nullptr, time_kernel});
alpha = -1.f * scale;
beta = 1.f;
a_element_op = AElementOp{};
b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// device operation
// For real Intermediate Value re_2
// For real Intermediate Value
// E_real = E1_real + A_img * B_img
auto op = DeviceOpInstance{} ;
auto invoker = op.MakeInvoker();
auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_re.GetDeviceBuffer()},
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
e_device_buf_re.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
......@@ -252,15 +252,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});
// scale = 1.f ;
// a_element_op = AElementOp{};
// b_element_op = BElementOp{};
// cde_element_op = CDEElementOp{alpha, beta};
auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
auto argument_img1 = op_scale.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 0>{}},
std::array<const void*, 0>{},
e_device_buf_img.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
......@@ -275,18 +269,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
cde_element_op_scale);
if(!op.IsSupportedArgument(argument_img1))
if(!op_scale.IsSupportedArgument(argument_img1))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
std::cout << op_scale.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time_img1 = invoker.Run(argument_img1, StreamConfig{nullptr, time_kernel});
float ave_time_img1 = invoker_scale.Run(argument_img1, StreamConfig{nullptr, time_kernel});
alpha = 1.f * scale;
beta = 1.f;
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_re.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_img.GetDeviceBuffer()},
......@@ -325,8 +321,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * M * N * K * 2;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N) * 2;
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment