Commit 951a52b2 authored by letaoqin's avatar letaoqin
Browse files

rcr change to rrr

parent 635b5904
...@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break; break;
default: default:
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
} }
DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_re(sizeof(EDataType) *
e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_img(sizeof(EDataType) *
e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
// Intermediate Value For E Real and Img // Intermediate Value For E Real and Img
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf_re1(sizeof(EDataType) *
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img1(sizeof(EDataType) *
e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
a_device_buf_re.ToDevice(a_ms_ks_re.mData.data()); a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
b_device_buf_re.ToDevice(b_ns_ks_re.mData.data()); b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
...@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
// set zero for intermediate values // set zero for intermediate values
e_device_buf_re1.SetZero(); e_device_buf_re1.SetZero();
e_device_buf_img1.SetZero(); e_device_buf_img1.SetZero();
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta}; auto cde_element_op = CDEElementOp{alpha, beta};
...@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
// device operation // device operation
// For real Intermediate Value re_1 // For real Intermediate Value re_1
auto op = DeviceOpInstance{}; auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker(); auto invoker = op.MakeInvoker();
auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), auto argument_re1 =
b_device_buf_re.GetDeviceBuffer(), op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()}, b_device_buf_re.GetDeviceBuffer(),
e_device_buf_re1.GetDeviceBuffer(), std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
a_ms_ks_lengths, e_device_buf_re1.GetDeviceBuffer(),
a_ms_ks_strides, a_ms_ks_lengths,
b_ns_ks_lengths, a_ms_ks_strides,
b_ns_ks_strides, b_ns_ks_lengths,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths}, b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides}, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
e_ms_ns_lengths, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_strides, e_ms_ns_lengths,
a_element_op, e_ms_ns_strides,
b_element_op, a_element_op,
cde_element_op); b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument_re1)) if(!op.IsSupportedArgument(argument_re1))
{ {
...@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel}); float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});
alpha = -1.f; alpha = -1.f;
beta = 1.f; beta = 1.f;
...@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
// For real Intermediate Value re_2 // For real Intermediate Value re_2
// auto op = DeviceOpInstance{}; // auto op = DeviceOpInstance{};
// auto invoker = op.MakeInvoker(); // auto invoker = op.MakeInvoker();
auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), auto argument_re2 =
b_device_buf_img.GetDeviceBuffer(), op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()}, b_device_buf_img.GetDeviceBuffer(),
e_device_buf_re.GetDeviceBuffer(), std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
a_ms_ks_lengths, e_device_buf_re.GetDeviceBuffer(),
a_ms_ks_strides, a_ms_ks_lengths,
b_ns_ks_lengths, a_ms_ks_strides,
b_ns_ks_strides, b_ns_ks_lengths,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths}, b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides}, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
e_ms_ns_lengths, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_strides, e_ms_ns_lengths,
a_element_op, e_ms_ns_strides,
b_element_op, a_element_op,
cde_element_op); b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument_re2)) if(!op.IsSupportedArgument(argument_re2))
{ {
...@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel}); float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});
alpha = 1.f; alpha = 1.f;
beta = 1.f; beta = 1.f;
...@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
b_element_op = BElementOp{}; b_element_op = BElementOp{};
cde_element_op = CDEElementOp{alpha, beta}; cde_element_op = CDEElementOp{alpha, beta};
auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), auto argument_img1 =
b_device_buf_img.GetDeviceBuffer(), op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()}, b_device_buf_img.GetDeviceBuffer(),
e_device_buf_img1.GetDeviceBuffer(), std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
a_ms_ks_lengths, e_device_buf_img1.GetDeviceBuffer(),
a_ms_ks_strides, a_ms_ks_lengths,
b_ns_ks_lengths, a_ms_ks_strides,
b_ns_ks_strides, b_ns_ks_lengths,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths}, b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides}, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
e_ms_ns_lengths, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_strides, e_ms_ns_lengths,
a_element_op, e_ms_ns_strides,
b_element_op, a_element_op,
cde_element_op); b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument_img1)) if(!op.IsSupportedArgument(argument_img1))
{ {
...@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
alpha = 1.f; alpha = 1.f;
beta = 1.f; beta = 1.f;
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), auto argument_img2 =
b_device_buf_re.GetDeviceBuffer(), op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()}, b_device_buf_re.GetDeviceBuffer(),
e_device_buf_img.GetDeviceBuffer(), std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
a_ms_ks_lengths, e_device_buf_img.GetDeviceBuffer(),
a_ms_ks_strides, a_ms_ks_lengths,
b_ns_ks_lengths, a_ms_ks_strides,
b_ns_ks_strides, b_ns_ks_lengths,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths}, b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides}, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
e_ms_ns_lengths, std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_strides, e_ms_ns_lengths,
a_element_op, e_ms_ns_strides,
b_element_op, a_element_op,
cde_element_op); b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument_img2)) if(!op.IsSupportedArgument(argument_img2))
{ {
...@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel}); float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});
ck::index_t M = ck::index_t M =
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
...@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2; sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ; float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
...@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data()); e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());
auto isRealOk = 0; auto isRealOk = 0;
auto isImgOk = 0; auto isImgOk = 0;
if(do_verification) if(do_verification)
{ {
...@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
auto ref_op = ReferenceOpInstance{}; auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_op.MakeInvoker(); auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument_re = auto ref_argument_re = ref_op.MakeArgument(
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op); a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_re); ref_invoker.Run(ref_argument_re);
alpha = 1.f; alpha = 1.f;
beta = 1.f; beta = 1.f;
cde_element_op = CDEElementOp{alpha, beta}; cde_element_op = CDEElementOp{alpha, beta};
for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0) for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
{ {
for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1) for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
...@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
alpha = 1.f; alpha = 1.f;
beta = -1.f; beta = -1.f;
cde_element_op = CDEElementOp{alpha, beta}; cde_element_op = CDEElementOp{alpha, beta};
auto ref_argument_re1 = auto ref_argument_re1 = ref_op.MakeArgument(
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op); a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_re1); ref_invoker.Run(ref_argument_re1);
...@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
} }
} }
isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1; isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
// Img Part Verification // Img Part Verification
Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides); Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
auto ref_argument_img = auto ref_argument_img = ref_op.MakeArgument(
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op); a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_img); ref_invoker.Run(ref_argument_img);
alpha = 1.f; alpha = 1.f;
beta = 1.f; beta = 1.f;
cde_element_op = CDEElementOp{alpha, beta}; cde_element_op = CDEElementOp{alpha, beta};
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0) for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
...@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
} }
} }
auto ref_argument_img1 = auto ref_argument_img1 = ref_op.MakeArgument(
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op); a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_img1); ref_invoker.Run(ref_argument_img1);
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0) for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
...@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) ...@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
} }
} }
isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1; isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
return (isRealOk && isImgOk); return (isRealOk && isImgOk);
} }
......
...@@ -21,7 +21,7 @@ using AccDataType = F32; ...@@ -21,7 +21,7 @@ using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Row;
using D0Layout = Row; using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>; using DsLayout = ck::Tuple<D0Layout>;
using CLayout = Row; using CLayout = Row;
...@@ -41,7 +41,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio ...@@ -41,7 +41,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
// clang-format off // clang-format off
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType> template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType>
using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
Row, Col, DsLayout, CLayout, ADataType, BDataType, ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType,
DsDataType, CDataType, AccDataType, CShuffleDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
64, 64,
...@@ -51,14 +51,14 @@ using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMul ...@@ -51,14 +51,14 @@ using DeviceOpInstance_64_16_16_64 = ck::tensor_operation::device::DeviceGemmMul
1, 1, 1, 1,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0, 2, 8, 8, 0,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>,
2, 8, 8, 0, 1, 2, 2, 0,
1, 1, 1, 1,
S<1, 16, 1, 4>, S<4, 4>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>; S<1, 16, 1, 4>, S<4, 4>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType> template <typename ADataType, typename BDataType, typename DsDataType, typename CDataType>
using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3<
Row, Col, DsLayout, CLayout, ADataType, BDataType, ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType,
DsDataType, CDataType, AccDataType, CShuffleDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, AElementOp, BElementOp, CDEElementOp, GemmSpec,
64, 64,
...@@ -68,10 +68,10 @@ using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_ ...@@ -68,10 +68,10 @@ using DeviceOpInstance_default = ck::tensor_operation::device::DeviceGemmMultiD_
1, 1, 1, 1,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 1, 1, 0, 2, 1, 1, 0,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>,
2, 1, 1, 0, 1, 1, 1, 0,
1, 1, 1, 1,
S<1, 16, 1, 4>, S<1, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>; S<1, 16, 1, 4>, S<2, 2>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16>;
// clang-format on // clang-format on
...@@ -97,7 +97,7 @@ float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config ...@@ -97,7 +97,7 @@ float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config
auto cde_element_op = CDEElementOp{}; auto cde_element_op = CDEElementOp{};
ck::index_t StrideA = args.K; ck::index_t StrideA = args.K;
ck::index_t StrideB = args.K; ck::index_t StrideB = args.N;
ck::index_t StrideD = 0; ck::index_t StrideD = 0;
ck::index_t StrideC = args.N; ck::index_t StrideC = args.N;
...@@ -116,6 +116,7 @@ float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config ...@@ -116,6 +116,7 @@ float gemm_bias_add_fp16(const GemmBiasAddArgs& args, const StreamConfig& config
StrideB, StrideB,
std::array<ck::index_t, NumDTensor>{StrideD}, std::array<ck::index_t, NumDTensor>{StrideD},
StrideC, StrideC,
1,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
......
...@@ -32,36 +32,25 @@ using Row = ck::tensor_layout::gemm::RowMajor; ...@@ -32,36 +32,25 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0Layout = Row; using A0Layout = Row;
using B0Layout = Col; using B0Layout = Row;
using D0Layout = Row; using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>; using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row; using ELayout = Row;
void RunUnfusedTest(const std::vector<ck::half_t>& mat_A, using PassThrough = ck::tensor_operation::element_wise::PassThrough;
const std::vector<ck::half_t>& mat_B, // using Add = ck::tensor_operation::element_wise::Add;
const std::vector<ck::half_t>& mat_C,
std::vector<ck::half_t>& mat_D, using AElementOp = PassThrough;
int K, using BElementOp = PassThrough;
int M, using CElementOp = PassThrough;
int N)
{ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
for(int m = 0; m < M; m++) B0DataType,
{ EDataType,
std::vector<float> tmp; AccDataType,
for(int n = 0; n < N; n++) AElementOp,
{ BElementOp,
float psum = 0.f; CElementOp>;
for(int k = 0; k < K; k++)
{
float areg = float(mat_A[m * K + k]);
float breg = float(mat_B[n * K + k]);
psum += areg * breg;
}
psum += ck::type_convert<float>(mat_C[n]);
mat_D[m * N + n] = ck::type_convert<ck::half_t>(psum);
}
}
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -70,12 +59,12 @@ int main(int argc, char* argv[]) ...@@ -70,12 +59,12 @@ int main(int argc, char* argv[])
bool time_kernel = true; bool time_kernel = true;
// GEMM shape // GEMM shape
ck::index_t M = 512; ck::index_t M = 16;
ck::index_t N = 1024; ck::index_t N = 16;
ck::index_t K = 256; ck::index_t K = 64;
ck::index_t StrideA = K; ck::index_t StrideA = K;
ck::index_t StrideB = K; ck::index_t StrideB = N;
ck::index_t StrideD = 0; ck::index_t StrideD = 0;
ck::index_t StrideE = N; ck::index_t StrideE = N;
...@@ -143,12 +132,12 @@ int main(int argc, char* argv[]) ...@@ -143,12 +132,12 @@ int main(int argc, char* argv[])
case 1: case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5}); a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5}); d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{0});
break; break;
default: default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5}); b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5}); d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{0});
} }
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
...@@ -188,7 +177,15 @@ int main(int argc, char* argv[]) ...@@ -188,7 +177,15 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
RunUnfusedTest(a0_m_k.mData, b0_k_n.mData, d0_m_n.mData, e_m_n_host_result.mData, K, M, N); // RunUnfusedTest(a0_m_k.mData, b0_k_n.mData, d0_m_n.mData, e_m_n_host_result.mData, K, M,
// N);
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, e_m_n_host_result, AElementOp{}, BElementOp{}, CElementOp{});
ref_invoker.Run(ref_argument);
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment