Commit c0ff964b authored by muozturk's avatar muozturk
Browse files

validation check in progress

parent 88c36f83
......@@ -18,7 +18,7 @@
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
#include "Complex.hpp"
// #include "Complex.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -182,9 +182,9 @@ int main(int argc, char* argv[])
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl;
std::cout << "a_ms_ks: " << a_ms_ks_re.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks_re.mDesc << std::endl;
std::cout << "d_ms_ns: " << d_ms_ns_re.mDesc << std::endl;
std::cout << "e_ms_ns: " << e_ms_ns_host_result_re.mDesc << std::endl;
switch(init_method)
......@@ -210,21 +210,21 @@ int main(int argc, char* argv[])
break;
}
DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
// Intermediate Value For E Real and Img
// LookAtHere
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
// DeviceMem e_device_buf_re2(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
// DeviceMem e_device_buf_img2(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize());
// LookAtHere
......@@ -273,7 +273,7 @@ int main(int argc, char* argv[])
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
if(!op.IsSupportedArgument(argument_re1))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
......@@ -286,14 +286,14 @@ int main(int argc, char* argv[])
alpha = -1.f;
beta = 1.f;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
a_element_op = AElementOp{};
b_element_op = BElementOp{};
cde_element_op = CDEElementOp{alpha, beta};
// device operation
// For real Intermediate Value re_2
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
// auto op = DeviceOpInstance{};
// auto invoker = op.MakeInvoker();
auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
......@@ -310,7 +310,7 @@ int main(int argc, char* argv[])
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument))
if(!op.IsSupportedArgument(argument_re2))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
......@@ -408,12 +408,11 @@ int main(int argc, char* argv[])
e_device_buf_re.FromDevice(e_ms_ns_device_result_re.mData.data());
e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());
auto isRealOk = 0;
if(do_verification)
{
// Real Part Verification
Tensor<CShuffleDataType> c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<CShuffleDataType> c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides);
......@@ -425,6 +424,7 @@ int main(int argc, char* argv[])
BDataType,
CShuffleDataType,
AccDataType,
F32,
AElementOp,
BElementOp>;
......@@ -469,8 +469,8 @@ int main(int argc, char* argv[])
cde_element_op = CDEElementOp{alpha, beta};
bool isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
}
return 0;
return isRealOk;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment