Commit 674f74ad authored by myamlak
Browse files

Test fixes.

parent 14bd1430
...@@ -60,7 +60,7 @@ struct ReferenceCGemm : public device::BaseOperator ...@@ -60,7 +60,7 @@ struct ReferenceCGemm : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
auto f_mk_kn_mn_real = [&](auto m, auto n) { auto f_mk_kn_mn_real = [&](auto m, auto n) {
const int K = arg.a_m_k_real_.mDesc.GetLengths()[1]; const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1]) if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
{ {
...@@ -69,7 +69,7 @@ struct ReferenceCGemm : public device::BaseOperator ...@@ -69,7 +69,7 @@ struct ReferenceCGemm : public device::BaseOperator
float v_acc = 0; float v_acc = 0;
for(int k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
float v_a_real; float v_a_real;
float v_b_real; float v_b_real;
...@@ -92,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator ...@@ -92,7 +92,7 @@ struct ReferenceCGemm : public device::BaseOperator
}; };
auto f_mk_kn_mn_imag = [&](auto m, auto n) { auto f_mk_kn_mn_imag = [&](auto m, auto n) {
const int K = arg.a_m_k_real_.mDesc.GetLengths()[1]; const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1];
if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1]) if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1])
{ {
...@@ -101,7 +101,7 @@ struct ReferenceCGemm : public device::BaseOperator ...@@ -101,7 +101,7 @@ struct ReferenceCGemm : public device::BaseOperator
float v_acc = 0; float v_acc = 0;
for(int k = 0; k < K; ++k) for(std::size_t k = 0; k < K; ++k)
{ {
float v_a_real; float v_a_real;
float v_b_real; float v_b_real;
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceCGemmNoOpPtr = using DeviceCGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>; ck::tensor_operation::element_wise::PassThrough>;
...@@ -48,9 +48,9 @@ int main() ...@@ -48,9 +48,9 @@ int main()
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true; bool res = true;
std::vector<DeviceCGemmNoOpPtr> gemmPtrs; std::vector<DeviceCGemmNoOpPtr> cgemmPtrs;
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(cgemmPtrs);
for(auto& cgemmPtr : cgemmPtrs) for(auto& cgemmPtr : cgemmPtrs)
...@@ -76,7 +76,7 @@ int main() ...@@ -76,7 +76,7 @@ int main()
RowMajor, RowMajor,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>{}(gemmPtr); PassThrough>{}(cgemmPtr);
} }
cgemmPtrs.clear(); cgemmPtrs.clear();
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceCGemmNoOpPtr = using DeviceCGemmNoOpPtr =
ck::tensor_operation::device::DevicecgemmPtr<ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>; ck::tensor_operation::element_wise::PassThrough>;
...@@ -50,10 +50,7 @@ int main() ...@@ -50,10 +50,7 @@ int main()
bool res = true; bool res = true;
std::vector<DeviceCGemmNoOpPtr> cgemmPtrs; std::vector<DeviceCGemmNoOpPtr> cgemmPtrs;
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_f16_f16_f16_km_kn_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(cgemmPtrs);
...@@ -72,10 +69,6 @@ int main() ...@@ -72,10 +69,6 @@ int main()
} }
cgemmPtrs.clear(); cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_f16_f16_f16_km_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(cgemmPtrs);
...@@ -94,10 +87,6 @@ int main() ...@@ -94,10 +87,6 @@ int main()
} }
cgemmPtrs.clear(); cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs);
...@@ -116,14 +105,8 @@ int main() ...@@ -116,14 +105,8 @@ int main()
} }
cgemmPtrs.clear(); cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
for(auto& cgemmPtr : cgemmPtrs) for(auto& cgemmPtr : cgemmPtrs)
{ {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceCGemmNoOpPtr = using DeviceCGemmNoOpPtr =
ck::tensor_operation::device::DevicecgemmPtr<ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>; ck::tensor_operation::element_wise::PassThrough>;
...@@ -54,10 +54,7 @@ int main() ...@@ -54,10 +54,7 @@ int main()
bool res = true; bool res = true;
std::vector<DeviceCGemmNoOpPtr> cgemmPtrs; std::vector<DeviceCGemmNoOpPtr> cgemmPtrs;
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_f32_f32_f32_km_kn_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(cgemmPtrs);
...@@ -76,10 +73,6 @@ int main() ...@@ -76,10 +73,6 @@ int main()
} }
cgemmPtrs.clear(); cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_f32_f32_f32_km_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(cgemmPtrs);
...@@ -98,10 +91,6 @@ int main() ...@@ -98,10 +91,6 @@ int main()
} }
cgemmPtrs.clear(); cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs);
...@@ -120,10 +109,6 @@ int main() ...@@ -120,10 +109,6 @@ int main()
} }
cgemmPtrs.clear(); cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs);
ck::tensor_operation::device::device_cgemm_instance:: ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs); add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs);
......
...@@ -77,21 +77,23 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, ...@@ -77,21 +77,23 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * A_real.mDesc.GetElementSpace());
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * A_imag.mDesc.GetElementSpace());
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * B_real.mDesc.GetElementSpace());
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * B_imag.mDesc.GetElementSpace());
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace());
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace());
DeviceMem aux_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); DeviceMem aux_device_buf(sizeof(CDataType) * Aux.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(A.mData.data()); a_m_k_real_device_buf.ToDevice(A_real.mData.data());
b_k_n_device_buf.ToDevice(B.mData.data()); a_m_k_imag_device_buf.ToDevice(A_imag.mData.data());
b_k_n_real_device_buf.ToDevice(B_real.mData.data());
b_k_n_imag_device_buf.ToDevice(B_imag.mData.data());
auto invoker_ptr = cgemmPtr->MakeInvokerPointer(); auto invoker_ptr = cgemmPtr->MakeInvokerPointer();
auto argument_ptr = cgemmPtr->MakeArgumentPointer( auto argument_ptr = cgemmPtr->MakeArgumentPointer(
static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()), static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()), static_cast<ADataType*>(a_m_k_imag_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
...@@ -255,7 +257,7 @@ struct TestCGemm ...@@ -255,7 +257,7 @@ struct TestCGemm
if(std::is_same<CDataType, float>::value) if(std::is_same<CDataType, float>::value)
{ {
res = ck::utils::check_err(c_device_real.mData, c_host_real.mData) && res = ck::utils::check_err(c_device_real.mData, c_host_real.mData) &&
ck::utils::check_err(c_device_real.mData, c_host.mData); ck::utils::check_err(c_device_imag.mData, c_host_imag.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, ck::half_t>::value) else if(std::is_same<CDataType, ck::half_t>::value)
...@@ -326,15 +328,13 @@ struct TestCGemmBF16 ...@@ -326,15 +328,13 @@ struct TestCGemmBF16
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<float> b_k_n_imag_fp32( Tensor<float> b_k_n_imag_fp32(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<float> c_m_n_host_real_fp32( Tensor<float> c_m_n_real_host_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> c_m_n_host_imag_fp32( Tensor<float> c_m_n_imag_host_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> c_m_n_device_real_fp32( Tensor<float> c_m_n_real_device_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> c_m_n_device_imag_fp32( Tensor<float> c_m_n_imag_device_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> aux_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k_real_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5}); a_m_k_real_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
...@@ -361,8 +361,7 @@ struct TestCGemmBF16 ...@@ -361,8 +361,7 @@ struct TestCGemmBF16
c_m_n_real_host_fp32, c_m_n_real_host_fp32,
c_m_n_imag_host_fp32, c_m_n_imag_host_fp32,
c_m_n_real_device_fp32, c_m_n_real_device_fp32,
c_m_n_imag_device_fp32, c_m_n_imag_device_fp32);
aux_fp32);
} }
auto operator()(DeviceCGemmPtr_& cgemmPtr) auto operator()(DeviceCGemmPtr_& cgemmPtr)
...@@ -392,32 +391,31 @@ struct TestCGemmBF16 ...@@ -392,32 +391,31 @@ struct TestCGemmBF16
Tensor<float>& c_imag_host_fp32 = std::get<12>(host_tensors); Tensor<float>& c_imag_host_fp32 = std::get<12>(host_tensors);
Tensor<float>& c_real_device_fp32 = std::get<13>(host_tensors); Tensor<float>& c_real_device_fp32 = std::get<13>(host_tensors);
Tensor<float>& c_imag_device_fp32 = std::get<14>(host_tensors); Tensor<float>& c_imag_device_fp32 = std::get<14>(host_tensors);
Tensor<float>& aux_fp32 = std::get<15>(host_tensors);
auto a_element_op = AElementwiseOperation{}; auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{}; auto b_element_op = BElementwiseOperation{};
auto c_element_op = CElementwiseOperation{}; auto c_element_op = CElementwiseOperation{};
// use fp32 host kernel to verify bf16 device kernel // use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance = using ReferenceCGemmInstance =
ck::tensor_operation::host::ReferenceCGemm<float, ck::tensor_operation::host::ReferenceCGemm<float,
float, float,
float, float,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; CElementwiseOperation>;
ck::gemm_util::RunHostCGEMM<ReferenceCGemmInstance>(a_real_fp32, ck::cgemm_util::RunHostCGEMM<ReferenceCGemmInstance>(a_real_fp32,
a_imag_fp32, a_imag_fp32,
b_real_fp32, b_real_fp32,
b_imag_fp32, b_imag_fp32,
c_real_host_fp32, c_real_host_fp32,
c_imag_fp32, c_imag_host_fp32,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
// Act // Act
ck::gemm_util::RunDeviceCGEMM(cgemmPtr, ck::cgemm_util::RunDeviceCGEMM(cgemmPtr,
params, params,
a_real_bf16, a_real_bf16,
a_imag_bf16, a_imag_bf16,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment