"doc/vscode:/vscode.git/clone" did not exist on "bee6291dc20d7c6a94b822bbf57835d0f891bbeb"
Unverified commit 5936aad9, authored by Mingtao Gu, committed by GitHub

Merge pull request #3 from ROCm/i4_update

I4 update
parents 5662fc11 d06be51d
@@ -39,7 +39,7 @@ using DeviceGemmV2Instance =
        128, 128,
        KPerBlock, 8, 32,
        32,   32,
-       2,    2,
+       4,    1,
        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
        2, 8, 8, 0,
        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>,
......
@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    switch(init_method)
    {
    case 0: break;
    case 1:
        a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
        break;
    default:
        a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
        break;
    }

    DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_re(sizeof(EDataType) *
                              e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());

    DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
    DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
    DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_img(sizeof(EDataType) *
                               e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

    // Intermediate Value For E Real and Img
    DeviceMem e_device_buf_re1(sizeof(EDataType) *
                               e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf_img1(sizeof(EDataType) *
                                e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());

    a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
    b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // set zero for intermediate values
    e_device_buf_re1.SetZero();
    e_device_buf_img1.SetZero();

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto cde_element_op = CDEElementOp{alpha, beta};
@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // device operation
    // For real Intermediate Value re_1
    auto op      = DeviceOpInstance{};
    auto invoker = op.MakeInvoker();

    auto argument_re1 =
        op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
                        b_device_buf_re.GetDeviceBuffer(),
                        std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
                        e_device_buf_re1.GetDeviceBuffer(),
                        a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                        e_ms_ns_lengths,
                        e_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op);
    if(!op.IsSupportedArgument(argument_re1))
    {
@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});

    alpha = -1.f;
    beta = 1.f;
@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    // For real Intermediate Value re_2
    // auto op = DeviceOpInstance{};
    // auto invoker = op.MakeInvoker();
    auto argument_re2 =
        op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
                        b_device_buf_img.GetDeviceBuffer(),
                        std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
                        e_device_buf_re.GetDeviceBuffer(),
                        a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                        e_ms_ns_lengths,
                        e_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op);
    if(!op.IsSupportedArgument(argument_re2))
    {
@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});

    alpha = 1.f;
    beta = 1.f;
@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    b_element_op = BElementOp{};
    cde_element_op = CDEElementOp{alpha, beta};

    auto argument_img1 =
        op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
                        b_device_buf_img.GetDeviceBuffer(),
                        std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
                        e_device_buf_img1.GetDeviceBuffer(),
                        a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                        e_ms_ns_lengths,
                        e_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op);
    if(!op.IsSupportedArgument(argument_img1))
    {
@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    alpha = 1.f;
    beta = 1.f;
    auto argument_img2 =
        op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
                        b_device_buf_re.GetDeviceBuffer(),
                        std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
                        e_device_buf_img.GetDeviceBuffer(),
                        a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
                        std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
                        e_ms_ns_lengths,
                        e_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op);
    if(!op.IsSupportedArgument(argument_img2))
    {
@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});

    ck::index_t M =
        ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                            sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;

    float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;
    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
    e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());

    auto isRealOk = 0;
    auto isImgOk = 0;

    if(do_verification)
    {
@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
        auto ref_op = ReferenceOpInstance{};
        auto ref_invoker = ref_op.MakeInvoker();

        auto ref_argument_re = ref_op.MakeArgument(
            a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_re);

        alpha = 1.f;
        beta = 1.f;

        cde_element_op = CDEElementOp{alpha, beta};

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
        {
            for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
        alpha = 1.f;
        beta = -1.f;

        cde_element_op = CDEElementOp{alpha, beta};

        auto ref_argument_re1 = ref_op.MakeArgument(
            a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_re1);
@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

        // Img Part Verification
        Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
        Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);

        auto ref_argument_img = ref_op.MakeArgument(
            a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_img);

        alpha = 1.f;
        beta = 1.f;

        cde_element_op = CDEElementOp{alpha, beta};

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        auto ref_argument_img1 = ref_op.MakeArgument(
            a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);

        ref_invoker.Run(ref_argument_img1);

        for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
            }
        }

        isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;

    return (isRealOk && isImgOk);
}
......
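For orientation: the four device launches above implement the usual split of a complex bilinear contraction into four real contractions, with e_device_buf_re1 and e_device_buf_img1 holding the first partial result of each part. A scalar sketch of what is computed per output element, as the alpha/beta settings above suggest (illustrative only; names are local to this sketch, and the real kernels contract over the K dimensions rather than multiplying single scalars):

    // Scalar sketch of the four-launch decomposition (not code from this patch).
    struct Cplx
    {
        float re;
        float im;
    };

    Cplx bilinear_point(Cplx a, Cplx b, Cplx d, float alpha, float beta)
    {
        Cplx e;
        // re_1: e_re1 = alpha * a_re * b_re + beta * d_re
        // re_2: e_re  = -(a_im * b_im) + e_re1        (alpha = -1, beta = 1)
        e.re = alpha * (a.re * b.re) + beta * d.re - a.im * b.im;
        // img_1: e_img1 = a_re * b_im + d_im          (alpha = 1, beta = 1)
        // img_2: e_im   = a_im * b_re + e_img1        (alpha = 1, beta = 1)
        e.im = a.re * b.im + d.im + a.im * b.re;
        return e;
    }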
@@ -359,13 +359,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
        // Initialize C
        c_thread_buf.Clear();

-       StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
-                                 AccDataType,
-                                 1,
-                                 xdlops_gemm.GetRegSizePerXdlops(),
-                                 true>
-           c_thread_buf_per_scale;

        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k0) {
@@ -381,6 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                   b_block_buf,
+                                  b_scale_thread_buf[n0],
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k0, I0),
                                   b_thread_buf);
@@ -406,10 +400,31 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
-           static_for<0, MRepeat, 1>{}([&](auto m0) {
-               static_for<0, NRepeat, 1>{}([&](auto n0) {
-                   c_thread_buf_per_scale.Clear();
-                   static_for<0, KRepeat, 1>{}([&](auto k0) {
+           static_for<0, NRepeat, 1>{}([&](auto n0) {
+               b_scale_thread_copy.Run(b_scale_grid_desc,
+                                       b_scale_grid_buf,
+                                       b_scale_thread_desc,
+                                       make_tuple(n0, I0),
+                                       b_scale_thread_buf);
+               b_scale_thread_copy.MoveSrcSliceWindow(
+                   b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
+           });
+           if((i + 2) % num_loop_per_scale == 0)
+           {
+               b_scale_thread_copy.MoveSrcSliceWindow(
+                   b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
+           }
+           else
+           {
+               b_scale_thread_copy.MoveSrcSliceWindow(
+                   b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
+           }
+           static_for<0, KRepeat, 1>{}([&](auto k0) {
+               static_for<0, MRepeat, 1>{}([&](auto m0) {
+                   static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -426,20 +441,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                        using mfma_input_type =
                            typename vector_type<ComputeDataType,
                                                 xdlops_gemm.K1PerXdlops>::type;

-                       // constexpr index_t c_offset =
-                       //     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                       xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                       b_thread_vec.template AsType<mfma_input_type>(),
-                                       c_thread_buf_per_scale.GetVectorTypeReference(I0));
-                   });
-                   static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
-                       constexpr index_t c_offset =
-                           c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
-                       c_thread_buf(Number<c_offset>{}) +=
-                           c_thread_buf_per_scale[Number<t>{}] *
-                           // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
-                           type_convert<AccDataType>(b_scale_thread_buf[n0]);
+                       constexpr index_t c_offset =
+                           c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                       xdlops_gemm.Run(
+                           a_thread_vec.template AsType<mfma_input_type>(),
+                           b_thread_vec.template AsType<mfma_input_type>(),
+                           c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
@@ -459,32 +467,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                   b_block_buf,
+                                  b_scale_thread_buf[n0],
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k0, I0),
                                   b_thread_buf);
            });
        });
-           static_for<0, NRepeat, 1>{}([&](auto n0) {
-               b_scale_thread_copy.Run(b_scale_grid_desc,
-                                       b_scale_grid_buf,
-                                       b_scale_thread_desc,
-                                       make_tuple(n0, I0),
-                                       b_scale_thread_buf);
-               b_scale_thread_copy.MoveSrcSliceWindow(
-                   b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
-           });
-           if((i + 2) % num_loop_per_scale == 0)
-           {
-               b_scale_thread_copy.MoveSrcSliceWindow(
-                   b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{}));
-           }
-           else
-           {
-               b_scale_thread_copy.MoveSrcSliceWindow(
-                   b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{}));
-           }

            HotLoopScheduler();
            __builtin_amdgcn_sched_barrier(0);
@@ -495,10 +483,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
-           static_for<0, MRepeat, 1>{}([&](auto m0) {
-               static_for<0, NRepeat, 1>{}([&](auto n0) {
-                   c_thread_buf_per_scale.Clear();
-                   static_for<0, KRepeat, 1>{}([&](auto k0) {
+           static_for<0, KRepeat, 1>{}([&](auto k0) {
+               static_for<0, MRepeat, 1>{}([&](auto m0) {
+                   static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -514,17 +501,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

-                       xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                                b_thread_vec.template AsType<mfma_input_type>(),
-                                                c_thread_buf_per_scale.GetVectorTypeReference(I0));
-                   });
-                   static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
-                       constexpr index_t c_offset =
-                           c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
-                       c_thread_buf(Number<c_offset>{}) +=
-                           c_thread_buf_per_scale[Number<t>{}] *
-                           // type_convert<AccDataType>(a_scale_thread_buf[I0]) *
-                           type_convert<AccDataType>(b_scale_thread_buf[n0]);
+                       constexpr index_t c_offset =
+                           c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                       xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                       b_thread_vec.template AsType<mfma_input_type>(),
+                                       c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
......
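The net effect of the pipeline hunks above, as far as they show, is to drop the per-tile accumulator c_thread_buf_per_scale and instead pass b_scale_thread_buf[n0] into b_thread_copy_.Run, so B is dequantized and scaled while it is copied out of LDS and the MFMA accumulates straight into c_thread_buf. A scalar model of the two placements of the scale (illustrative only; the kernel works on xdlops fragments, not single floats):

    // Before: accumulate an unscaled partial sum, then multiply by the scale.
    float dot_scale_after(const float* a, const float* b_q, float b_scale, int K)
    {
        float partial = 0.f; // analogue of c_thread_buf_per_scale
        for(int k = 0; k < K; ++k)
            partial += a[k] * b_q[k];
        return partial * b_scale; // scale applied after accumulation
    }

    // After: scale B once as it is loaded, accumulate directly into C.
    float dot_scale_on_load(const float* a, const float* b_q, float b_scale, int K)
    {
        float c = 0.f; // analogue of c_thread_buf
        for(int k = 0; k < K; ++k)
            c += a[k] * (b_q[k] * b_scale); // scale fused into the B fragment load
        return c;
    }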
@@ -220,9 +220,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
        constexpr index_t minimum_occupancy =
            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
                ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
-                  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) > 128 * 128 * 64 * 2)
-                      ? 1
-                      : 2
+                  MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
+                      ? 2
+                      : 1
                : 2;

        if(has_main_k_block_loop)
......
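One behavioral detail of the occupancy hunk above: for the v3 pipeline, inverting the comparison and swapping the branch values selects the same occupancy as before, so the visible change is that the other Intrawave pipelines now fall through to 1 instead of 2. A standalone sketch of both expressions (threshold and names are local to this sketch, not code from the patch):

    #include <cstdio>

    // Old and new minimum_occupancy expressions side by side; T mirrors the
    // 128 * 128 * 64 * 2 threshold used in the hunk above.
    constexpr long T = 128L * 128 * 64 * 2;

    constexpr int occupancy_old(bool intrawave, bool v3, long tile_bytes)
    {
        return intrawave ? ((v3 && tile_bytes > T) ? 1 : 2) : 2;
    }

    constexpr int occupancy_new(bool intrawave, bool v3, long tile_bytes)
    {
        return intrawave ? ((v3 && tile_bytes <= T) ? 2 : 1) : 2;
    }

    int main()
    {
        // v3 keeps its value on both sides of the threshold...
        static_assert(occupancy_old(true, true, T + 1) == occupancy_new(true, true, T + 1), "");
        static_assert(occupancy_old(true, true, T) == occupancy_new(true, true, T), "");
        // ...while a non-v3 Intrawave pipeline drops from 2 to 1.
        std::printf("non-v3 intrawave: old=%d new=%d\n",
                    occupancy_old(true, false, T), occupancy_new(true, false, T));
        return 0;
    }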
@@ -11,6 +11,98 @@
namespace ck {
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
{
constexpr int LO = 0x000f000f;
constexpr int HI = 0x00f000f0;
constexpr int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = amd_assembly_and_or_b32(q, LO, EX);
int hi = amd_assembly_and_or_b32(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
constexpr int SUB = 0xE408E408; //-8
constexpr int MUL = 0x2c002c00; // 1/16
constexpr int ADD = 0xd480d480; // -72 (= -(1024 / 16) - 8)
vector_type<half_t, 4> res;
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
asm volatile("v_pk_mul_f16 %0, %1, %2"
: "=v"(res.template AsType<half2_t>()(Number<0>{}))
: "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
asm volatile("v_pk_mul_f16 %0, %1, %2"
: "=v"(res.template AsType<half2_t>()(Number<1>{}))
: "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
return res.template AsType<half4_t>()[Number<0>{}];
}
// Variant that fuses the scale further into the inline assembly; its sanity check failed, so it is disabled.
#if 0
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half_t& scale)
{
constexpr int LO = 0x000f000f;
constexpr int HI = 0x00f000f0;
constexpr int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
// int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
int lo = amd_assembly_and_or_b32(q, LO, EX);
int hi = amd_assembly_and_or_b32(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
// constexpr int SUB = 0xE408E408; //-8
// constexpr int MUL = 0x2c002c00; // 1/16
// constexpr int ADD = 0xd480d480; //-79
constexpr half_t SUB = bit_cast<half_t>(static_cast<uint16_t>(0xE408));
constexpr half_t MUL = bit_cast<half_t>(static_cast<uint16_t>(0x2c00));
constexpr half_t ADD = bit_cast<half_t>(static_cast<uint16_t>(0xd480));
vector_type<half_t, 2> scale_2;
scale_2.template AsType<half_t>()(Number<0>{}) = scale;
scale_2.template AsType<half_t>()(Number<1>{}) = scale;
vector_type<half_t, 2> sub_2;
sub_2.template AsType<half_t>()(Number<0>{}) = SUB * scale;
sub_2.template AsType<half_t>()(Number<1>{}) = SUB * scale;
vector_type<half_t, 2> mul_2;
mul_2.template AsType<half_t>()(Number<0>{}) = MUL * scale;
mul_2.template AsType<half_t>()(Number<1>{}) = MUL * scale;
vector_type<half_t, 2> add_2;
add_2.template AsType<half_t>()(Number<0>{}) = ADD * scale;
add_2.template AsType<half_t>()(Number<1>{}) = ADD * scale;
vector_type<half_t, 4> res;
res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_fma_f16(bit_cast<half2_t>(lo),
scale_2.template AsType<half2_t>()(Number<0>{}),
sub_2.template AsType<half2_t>()(Number<0>{}));
res.template AsType<half2_t>()(Number<1>{}) =
amd_assembly_pk_fma_f16(bit_cast<half2_t>(hi),
mul_2.template AsType<half2_t>()(Number<0>{}),
add_2.template AsType<half2_t>()(Number<0>{}));
// asm volatile("v_pk_mul_f16 %0, %1, %2"
// : "=v"(res.template AsType<half2_t>()(Number<0>{}))
// : "v"(res.template AsType<half2_t>()(Number<0>{})), "v"(scale));
// asm volatile("v_pk_mul_f16 %0, %1, %2"
// : "=v"(res.template AsType<half2_t>()(Number<1>{}))
// : "v"(res.template AsType<half2_t>()(Number<1>{})), "v"(scale));
return res.template AsType<half4_t>()[Number<0>{}];
}
#endif
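For readers new to this packed-int4 trick: OR-ing a 4-bit code into the mantissa of an fp16 whose exponent field encodes 1024 (EX = 0x6400) produces 1024 + v for the low nibble and 1024 + 16 * v for the high one; subtracting 1032 (SUB), or multiplying by 1/16 (MUL) and subtracting 72 (ADD), then recovers the signed value v - 8, and the half2 scale is applied with packed multiplies. A plain scalar model of one 16-bit lane (my reading of the constants, not code from this patch):

    #include <cstdint>
    #include <cstdio>

    // Scalar model of pki4_to_half4_scale for one 16-bit lane (illustrative only;
    // the device code does this on packed half2 lanes with v_pk_* instructions).
    // lane16 carries two 4-bit codes: bits [3:0] (masked by LO) and [7:4] (masked by HI).
    void dequant_lane(uint16_t lane16, float scale, float out[2])
    {
        const int lo = lane16 & 0xF;
        const int hi = (lane16 >> 4) & 0xF;

        // (q & LO) | EX reads as fp16(1024 + lo); the packed add of SUB = fp16(-1032)
        // yields lo - 8, i.e. the signed int4 value.
        out[0] = static_cast<float>((1024 + lo) - 1032) * scale;

        // (q & HI) | EX reads as fp16(1024 + 16 * hi); the packed fma with MUL = 1/16
        // and ADD = fp16(-72) yields hi - 8. The trailing multiply models v_pk_mul_f16.
        out[1] = static_cast<float>((1024 + 16 * hi) / 16 - 72) * scale;
    }

    int main()
    {
        float r[2];
        dequant_lane(0x3A, 2.0f, r); // codes 0xA and 0x3 -> (10 - 8) * 2 = 4, (3 - 8) * 2 = -10
        std::printf("%g %g\n", r[0], r[1]);
        return 0;
    }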
__host__ __device__ inline half4_t pki4_to_half4(int q)
{
    constexpr int LO = 0x000f000f;
@@ -119,6 +211,69 @@ struct PassThroughPack8
        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
y = result.template AsType<half8_t>()[Number<0>{}];
#else
vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
}
constexpr const static bool is_pack8_invocable = true;
};
struct DequantPack8
{
template <typename Y, typename X, typename Z>
__host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
__host__ __device__ constexpr void
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
{
#if 0
int x_permute = 0;
int bits4_0 = (bit_cast<int>(x) >> 0) & 0xF;
int bits4_1 = (bit_cast<int>(x) >> 4) & 0xF;
int bits4_2 = (bit_cast<int>(x) >> 8) & 0xF;
int bits4_3 = (bit_cast<int>(x) >> 12) & 0xF;
int bits4_4 = (bit_cast<int>(x) >> 16) & 0xF;
int bits4_5 = (bit_cast<int>(x) >> 20) & 0xF;
int bits4_6 = (bit_cast<int>(x) >> 24) & 0xF;
int bits4_7 = (bit_cast<int>(x) >> 28) & 0xF;
x_permute |= (bits4_1 << 0);
x_permute |= (bits4_3 << 4);
x_permute |= (bits4_5 << 8);
x_permute |= (bits4_7 << 12);
x_permute |= (bits4_0 << 16);
x_permute |= (bits4_2 << 20);
x_permute |= (bits4_4 << 24);
x_permute |= (bits4_6 << 28);
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(x_permute, z);
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4_scale(x_permute >> 8, z);
y = result.template AsType<half8_t>()[Number<0>{}];
#elif 1
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
result.template AsType<half4_t>()(Number<1>{}) =
pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
        y = result.template AsType<half8_t>()[Number<0>{}];
#else
        vector_type<half_t, 8> dst;
......
@@ -1914,7 +1914,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
            make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset, 0));

        constexpr auto b_scale_thread_slice_copy_step =
-           make_tuple(make_multi_index(NWaves * NPerXdl, 0), make_multi_index(-NPerBlock, 1));
+           make_tuple(make_multi_index(NWaves * NPerXdl, 0),
+                      make_multi_index(-NPerBlock, 0),
+                      make_multi_index(-NPerBlock, 1));

        const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
......
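Read together with the pipeline hunk earlier, the three entries of b_scale_thread_slice_copy_step appear to be: step 0 walks the N repeats inside one block tile (NWaves * NPerXdl), step 1 rewinds N while staying on the current K scale block, and step 2 rewinds N and advances one step in the scale-K direction; the pipeline takes step 2 only every num_loop_per_scale iterations. A tiny sketch of that selection (names are local to the sketch):

    // Which scale-window move to take at the end of main-loop iteration i
    // (mirrors the (i + 2) % num_loop_per_scale test in the pipeline hunk).
    int select_scale_copy_step(int i, int num_loop_per_scale)
    {
        // 1: rewind N only (-NPerBlock, 0); 2: rewind N and advance scale K (-NPerBlock, 1)
        return ((i + 2) % num_loop_per_scale == 0) ? 2 : 1;
    }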
@@ -1252,6 +1252,237 @@ struct ThreadwiseTensorSliceTransfer_v4
        });
    }
// Fuse scale
template <typename SrcRefToOriginDisplacement,
typename DstOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcRefToOriginDisplacement&,
const SrcBuffer& src_buf,
const DstData& scale,
const DstDesc&,
const DstOriginIdx&,
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
// SrcOriginToRefDistance and DstOriginToRefDistance are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
// scalar per access of each dim
constexpr auto src_scalar_per_access = generate_sequence_v2(
[&](auto i) constexpr {
if constexpr(i == SrcVectorDim)
{
return Number<SrcScalarPerVector>{};
}
else
{
return Number<1>{};
}
},
Number<nDim>{});
// scalar step (if stepping on SrcVectorDim) of each dim
constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
[&](auto i) constexpr {
if constexpr(i == SrcVectorDim)
{
return Number<1>{};
}
else
{
return Number<0>{};
}
},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
#if 0
// TODO: unable to compile
// position in slice window
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
#else
// position in slice window
constexpr auto data_to_origin_disp_idx =
ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
// src coordinate
constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_step =
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_;
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector
if constexpr(SrcBuffer::IsDynamicBuffer())
{
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
is_src_valid);
}
else if constexpr(SrcBuffer::IsStaticBuffer())
{
static_assert(false, "");
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
i * src_scalar_step_in_vector);
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
});
}
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
vector_type<DstData, 2> scale_vector;
scale_vector.template AsType<DstData>()(Number<0>{}) = scale;
scale_vector.template AsType<DstData>()(Number<1>{}) = scale;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
using scale_v_t = typename vector_type_maker_t<DstData, 2>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::DequantPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i],
scale_vector.template AsType<scale_v_t>()[Number<0>{}]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
is_same<remove_cvref_t<DstData>, f8_t>::value)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
is_same<remove_cvref_t<DstData>, half_t>::value &&
SrcScalarPerVector % 2 == 0)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 2;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack2{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
else
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
// TODO: if SrcData and DstData are vector types, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
}
});
}
    template <typename SrcSliceMoveStepIdx>
    __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                       const SrcSliceMoveStepIdx& src_slice_move_step_idx)
......
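The new Run overload above differs from the existing one mainly in the extra DstData scale argument: for the pk_i4_t to half_t path the scale is broadcast into a half2 and handed to DequantPack8, so dequantization and scaling happen during the LDS-to-register copy, and the pipeline hunk earlier passes b_scale_thread_buf[n0] at the call site. A condensed device-side sketch of that branch (assumes the ck headers used above; not the transfer's actual coordinate walk):

    // Dequantize-and-scale copy, 8 int4 values at a time (illustrative sketch).
    __device__ void dequant_copy(const ck::pk_i4x4_t* src,
                                 ck::half_t scale,
                                 ck::half8_t* dst,
                                 int num_pack8)
    {
        ck::half2_t scale2 = {scale, scale}; // same broadcast as scale_vector above
        for(int i = 0; i < num_pack8; ++i)
        {
            // each pk_i4x4_t expands to 8 halves, scaled on the fly
            ck::tensor_operation::element_wise::DequantPack8{}(dst[i], src[i], scale2);
        }
    }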
@@ -21,14 +21,14 @@ inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
    half2_t d;
-   asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+   asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c));
    return d;
}

inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
    half2_t c;
-   asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+   asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
    return c;
}
......
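For reference, the two wrappers compute per 16-bit lane: amd_assembly_pk_fma_f16(a, b, c)[i] = a[i] * b[i] + c[i] and amd_assembly_pk_add_f16(a, b)[i] = a[i] + b[i]. A minimal device-side use, assuming it sits in the same header and namespace as the wrappers (illustrative only):

    // Packed fp16 affine transform using the wrappers above.
    inline __device__ half2_t affine_pk_f16(half2_t x, half2_t scale, half2_t bias)
    {
        // one v_pk_fma_f16: both fp16 lanes get x * scale + bias
        return amd_assembly_pk_fma_f16(x, scale, bias);
    }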
@@ -17,7 +17,7 @@ fi
cmake \
    -D CMAKE_PREFIX_PATH=/opt/rocm \
    -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-   -D CMAKE_HIP_FLAGS="-save-temps -gline-tables-only -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
+   -D CMAKE_HIP_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17" \
    -D CMAKE_BUILD_TYPE=Release \
    -D BUILD_DEV=ON \
    -D GPU_TARGETS=$GPU_TARGETS \
......