"...composable_kernel.git" did not exist on "ecf337bab5c23708d80a4c537c6b49dbda6e23b2"
Commit 1e5e2dc3 authored by muozturk

merge

parents 571e8728 81eece66
@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    switch(init_method)
    {
    case 0: break;
    case 1:
        a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    default:
        a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        break;
    }

    DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_re(sizeof(EDataType) *
                              e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());

    DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_img(sizeof(EDataType) *
                               e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

    // Intermediate Value For E Real and Img
    DeviceMem e_device_buf_re1(sizeof(EDataType) *
                               e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_img1(sizeof(EDataType) *
                                e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

    a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
    b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
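For context, the DeviceMem buffers above follow a plain allocate / upload / (clear) / download round trip. A minimal standalone sketch of that pattern, assuming only the calls that actually appear in this diff (byte-count constructor, ToDevice, SetZero, GetDeviceBuffer, FromDevice from ck/library/utility/device_memory.hpp):

#include <vector>

#include "ck/library/utility/device_memory.hpp"

// Round trip for a single buffer; the example above repeats this once per tensor.
void device_mem_round_trip_sketch(std::vector<float>& host)
{
    DeviceMem buf(sizeof(float) * host.size()); // allocate device memory (size in bytes)
    buf.ToDevice(host.data());                  // upload the host data
    // ... pass buf.GetDeviceBuffer() to a kernel launch here ...
    buf.FromDevice(host.data());                // download the (possibly updated) data

    DeviceMem acc(sizeof(float) * host.size()); // accumulator buffers are cleared instead
    acc.SetZero();
}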
@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // set zero for intermediate values
    e_device_buf_re1.SetZero();
    e_device_buf_img1.SetZero();

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto cde_element_op = CDEElementOp{alpha, beta};
@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // device operation
    // For real Intermediate Value re_1
    auto op = DeviceOpInstance{};
    auto invoker = op.MakeInvoker();

    auto argument_re1 =
        op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
                        b_device_buf_re.GetDeviceBuffer(),
                        std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
                        e_device_buf_re1.GetDeviceBuffer(),
                        a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                        e_ms_ns_lengths,
                        e_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op);

    if(!op.IsSupportedArgument(argument_re1))
    {
@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});

    alpha = -1.f;
    beta = 1.f;
@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // For real Intermediate Value re_2
    // auto op = DeviceOpInstance{};
    // auto invoker = op.MakeInvoker();
    auto argument_re2 =
        op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
                        b_device_buf_img.GetDeviceBuffer(),
                        std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
                        e_device_buf_re.GetDeviceBuffer(),
                        a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                        e_ms_ns_lengths,
                        e_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op);

    if(!op.IsSupportedArgument(argument_re2))
    {
@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});

    alpha = 1.f;
    beta = 1.f;
@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    b_element_op = BElementOp{};
    cde_element_op = CDEElementOp{alpha, beta};

    auto argument_img1 =
        op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
                        b_device_buf_img.GetDeviceBuffer(),
                        std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
                        e_device_buf_img1.GetDeviceBuffer(),
                        a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                        e_ms_ns_lengths,
                        e_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op);

    if(!op.IsSupportedArgument(argument_img1))
    {
@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    alpha = 1.f;
    beta = 1.f;

    auto argument_img2 =
        op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
                        b_device_buf_re.GetDeviceBuffer(),
                        std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
                        e_device_buf_img.GetDeviceBuffer(),
                        a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                        e_ms_ns_lengths,
                        e_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op);

    if(!op.IsSupportedArgument(argument_img2))
    {
@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});

    ck::index_t M =
        ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
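The four MakeArgument/Run pairs above assemble one complex bilinear contraction E = A * B + D out of four real contractions, flipping the Bilinear op's alpha/beta between launches. A minimal scalar sketch of that identity, assuming the example starts with alpha = 1 and beta = 1; all names below are illustrative and not part of the diff:

#include <cassert>
#include <cmath>
#include <complex>

int main()
{
    const float a_re = 2.f, a_im = 3.f;    // stands in for the A_re / A_img tensors
    const float b_re = 5.f, b_im = -1.f;   // stands in for the B_re / B_img tensors
    const float d_re = 0.5f, d_im = 0.25f; // stands in for the D_re / D_img tensors

    float re1  = 1.f * (a_re * b_re) + 1.f * d_re;  // launch 1: re1   = alpha*(A_re x B_re) + beta*D_re
    float e_re = -1.f * (a_im * b_im) + 1.f * re1;  // launch 2: E_re  = -(A_img x B_img) + re1
    float img1 = 1.f * (a_re * b_im) + 1.f * d_im;  // launch 3: img1  = (A_re x B_img) + D_img
    float e_im = 1.f * (a_im * b_re) + 1.f * img1;  // launch 4: E_img = (A_img x B_re) + img1

    // Same result through std::complex arithmetic: E = A * B + D.
    const std::complex<float> e =
        std::complex<float>{a_re, a_im} * std::complex<float>{b_re, b_im} +
        std::complex<float>{d_re, d_im};
    assert(std::abs(e_re - e.real()) < 1e-5f && std::abs(e_im - e.imag()) < 1e-5f);
    return 0;
}

The two intermediate buffers e_device_buf_re1 / e_device_buf_img1 play the role of re1 and img1 here, which is why they are zero-initialized before the first launch.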
@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                            sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;

    float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());

    auto isRealOk = 0;
    auto isImgOk = 0;

    if(do_verification)
    {
@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
        auto ref_op = ReferenceOpInstance{};
        auto ref_invoker = ref_op.MakeInvoker();

        auto ref_argument_re = ref_op.MakeArgument(
            a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_re);

        alpha = 1.f;
        beta = 1.f;
        cde_element_op = CDEElementOp{alpha, beta};

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
        {
            for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
        alpha = 1.f;
        beta = -1.f;
        cde_element_op = CDEElementOp{alpha, beta};

        auto ref_argument_re1 = ref_op.MakeArgument(
            a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_re1);
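The verification loops surrounding these reference runs fold each host-side contraction result into the accumulated tensor element by element through cde_element_op. A minimal sketch of the assumed Bilinear{alpha, beta} semantics (e = alpha * c + beta * d); the exact operand order inside the collapsed loop bodies is not visible in this diff:

#include <cassert>

// Assumed behaviour of ck::tensor_operation::element_wise::Bilinear, sketched here
// with plain floats only to illustrate the combination step used by the host check.
struct BilinearSketch
{
    float alpha;
    float beta;

    void operator()(float& e, const float& c, const float& d) const { e = alpha * c + beta * d; }
};

int main()
{
    float e = 0.f;
    BilinearSketch{1.f, -1.f}(e, 3.f, 2.f); // e = 1 * 3 + (-1) * 2
    assert(e == 1.f);
    return 0;
}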
@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

        // Img Part Verification
        Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
        Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);

        auto ref_argument_img = ref_op.MakeArgument(
            a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_img);

        alpha = 1.f;
        beta = 1.f;
        cde_element_op = CDEElementOp{alpha, beta};

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        auto ref_argument_img1 = ref_op.MakeArgument(
            a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_img1);

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

    return (isRealOk && isImgOk);
}
...
@@ -19,9 +19,9 @@ static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;

using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
using CDEElementOp_Scale = ck::tensor_operation::element_wise::Scale;

using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<NumDimM,
@@ -39,18 +39,18 @@ using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<NumDimM,
                                                       CDEElementOp>;

using DeviceOpInstanceKKN_Scale = DeviceOpInstanceKK_Generic<NumDimM,
                                                             NumDimN,
                                                             NumDimK,
                                                             ADataType,
                                                             BDataType,
                                                             AccDataType,
                                                             CShuffleDataType,
                                                             DsDataType,
                                                             EDataType,
                                                             ComputeDataType,
                                                             AElementOp,
                                                             BElementOp,
                                                             CDEElementOp_Scale>;

using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<NumDimM,
                                                       NumDimN,
@@ -94,7 +94,7 @@ using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<NumDimM,
                                                       BElementOp,
                                                       CDEElementOp>;

using DeviceOpInstance = DeviceOpInstanceKKN;
using DeviceOpInstance_Scale = DeviceOpInstanceKKN_Scale;

#include "run_complex_contraction_scale_example.inc"
...
@@ -18,9 +18,9 @@ static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;

using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
using CDEElementOp_Scale = ck::tensor_operation::element_wise::Scale;

using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<NumDimM,
@@ -38,18 +38,18 @@ using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic<NumDimM,
                                                       CDEElementOp>;

using DeviceOpInstanceKKN_Scale = DeviceOpInstanceKK_Generic<NumDimM,
                                                             NumDimN,
                                                             NumDimK,
                                                             ADataType,
                                                             BDataType,
                                                             AccDataType,
                                                             CShuffleDataType,
                                                             DsDataType,
                                                             EDataType,
                                                             ComputeDataType,
                                                             AElementOp,
                                                             BElementOp,
                                                             CDEElementOp_Scale>;

using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic<NumDimM,
                                                       NumDimN,
@@ -93,7 +93,7 @@ using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic<NumDimM,
                                                       BElementOp,
                                                       CDEElementOp>;

using DeviceOpInstance = DeviceOpInstanceKKN;
using DeviceOpInstance_Scale = DeviceOpInstanceKKN_Scale;

#include "run_complex_contraction_scale_example.inc"
...
find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
@@ -12,10 +12,8 @@
#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"

using F32 = float;
using F64 = double;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -93,7 +91,6 @@ class TestContraction : public ::testing::Test
    }
};

template <typename Tuple>
class TestContractionBilinear : public TestContraction<Tuple>
{
@@ -109,10 +106,8 @@ using BilinearKernelTypes =
    ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<F32>, F32, Bilinear),
                     ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<F64>, F64, Bilinear)>;

TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);

TYPED_TEST(TestContractionBilinear, bilinear)
{
    this->p_cd_element_op = std::make_unique<Bilinear>(1.f, 1.f);
@@ -120,5 +115,3 @@ TYPED_TEST(TestContractionBilinear, bilinear)
    this->p_cd_element_op = std::make_unique<Bilinear>(-0.5f, 0.5f);
    this->Run();
}
@@ -12,14 +12,13 @@
#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"

using F32 = float;
using F64 = double;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using Scale = ck::tensor_operation::element_wise::Scale;

struct Dimensions
{
@@ -96,24 +95,17 @@ class TestContractionScale : public TestContraction<Tuple>
{
};

#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \
    std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>,  \
    std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>,  \
    std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>,  \
    std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>

using ScaleKernelTypes = ::testing::Types<ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, F32, Scale),
                                          ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>;

TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);

TYPED_TEST(TestContractionScale, scale)
{
    this->p_cd_element_op = std::make_unique<Scale>(1.f);
@@ -121,7 +113,3 @@ TYPED_TEST(TestContractionScale, scale)
    this->p_cd_element_op = std::make_unique<Scale>(0.5f);
    this->Run();
}
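The ALL_LAYOUT_COMBINATIONS macro above (also used by the bilinear test file) feeds ::testing::Types one std::tuple per (A, B, E) layout combination, so each typed test suite is instantiated four times per data type. A standalone sketch of that expansion, using simplified stand-in types (Row, Col, Scale and Tuple are defined locally here, not the CK ones):

#include <tuple>

struct Row {};
struct Col {};
struct Scale {};
template <typename...>
struct Tuple {}; // stand-in for ck::Tuple

#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \
    std::tuple<Row, Row, Row, dt, tuple_dt, compute_dt, op>,  \
    std::tuple<Row, Col, Row, dt, tuple_dt, compute_dt, op>,  \
    std::tuple<Col, Row, Row, dt, tuple_dt, compute_dt, op>,  \
    std::tuple<Col, Col, Row, dt, tuple_dt, compute_dt, op>

// std::tuple stands in for ::testing::Types; the element count is the same.
using ScaleKernelTypes =
    std::tuple<ALL_LAYOUT_COMBINATIONS(float, Tuple<>, float, Scale),
               ALL_LAYOUT_COMBINATIONS(double, Tuple<>, double, Scale)>;

static_assert(std::tuple_size<ScaleKernelTypes>::value == 8,
              "one typed-test instantiation per layout/data-type combination");

int main() { return 0; }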
@@ -15,9 +15,8 @@
#include "ck/library/utility/device_memory.hpp"

using Pass = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -56,7 +55,7 @@ class ContractionInstanceWrapper
        auto argument = contraction.MakeArgument(nullptr,
                                                 nullptr,
                                                 std::array<const void*, 1>{nullptr},
                                                 // std::array<const void*, 0>{},
                                                 nullptr,
                                                 ADims,
                                                 AStrides,
@@ -64,8 +63,8 @@ class ContractionInstanceWrapper
                                                 BStrides,
                                                 std::array<std::vector<ck::index_t>, 1>{DDims},
                                                 std::array<std::vector<ck::index_t>, 1>{DStrides},
                                                 // std::array<std::vector<ck::index_t>, 0>{},
                                                 // std::array<std::vector<ck::index_t>, 0>{},
                                                 EDims,
                                                 EStrides,
                                                 Pass{},
@@ -101,7 +100,8 @@ class ContractionDeviceOpWrapper
    {
        bool supported = false;
        const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
            DeviceOp>::GetInstances();

        for(auto& op_ptr : op_ptrs)
        {
@@ -192,7 +192,7 @@ TEST(TestContractionSupportedArgs, DEMemoryAccess)
    EXPECT_FALSE(
        wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, InvalidStrides, Strides));
    EXPECT_TRUE(wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, Strides));

    // Memory access to E
    EXPECT_FALSE(
        wrapper.isSupported(Dims, Dims, Dims, Dims, Strides, Strides, Strides, InvalidStrides));
...