Unverified commit 2ea75bd6, authored by Illia Silin, committed by GitHub

Resolve some data type issues and cmake policy. (#940)

* split the types in gemm_bilinear instances, add condition to cmake policy

* fix syntax

* split the data types in batchnorm examples

* fix the batchnorm_bwd test

* fix types in the batchnorm_bwd test
Parent: c9553832
 cmake_minimum_required(VERSION 3.14)
-cmake_policy(SET CMP0140 NEW)
+if(POLICY CMP0140)
+  # policy CMP0140 is not known to CMake until 3.25
+  cmake_policy(SET CMP0140 NEW)
+endif()
 # This has to be initialized before the project() command appears
 # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE
@@ -11,12 +11,12 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
-#ifdef CK_ENABLE_FP16
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
+#ifdef CK_ENABLE_FP16
 void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
                                                     Row,
@@ -68,7 +68,8 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
+#endif
+#ifdef CK_ENABLE_INT8
 void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
                                                     Row,
@@ -120,7 +121,7 @@ void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
                                                     PassThrough,
                                                     PassThrough,
                                                     Bilinear>>>& instances);
+#endif
 // GEMM + Bilinear
 template <typename ALayout,
           typename BLayout,
@@ -158,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                      is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
         {
@@ -187,7 +188,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                     op_ptrs);
             }
         }
-        else if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
+#endif
+#ifdef CK_ENABLE_INT8
+        if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
                      is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
@@ -211,7 +214,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
                 add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(op_ptrs);
             }
         }
+#endif
         return op_ptrs;
     }
 };
@@ -220,4 +223,3 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
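Taken together, the header change works like this: before the split, a single CK_ENABLE_FP16 guard opened right after the includes and closed at the end of the file, so the int8 WMMA declarations and the whole instance factory disappeared from builds without FP16 support. Now each data type carries its own guard, and the factory's `else if constexpr` chain becomes independent `if constexpr` blocks, so either branch can be compiled out on its own. Below is a minimal sketch of the resulting pattern, using stand-in names (fp16_t, DeviceOp, add_*_instances) rather than the library's real types:

#include <cstdint>
#include <memory>
#include <type_traits>
#include <vector>

struct fp16_t {};   // stand-in for ck::half_t
struct DeviceOp {}; // stand-in for the DeviceGemmMultipleD interface

void add_fp16_instances(std::vector<std::unique_ptr<DeviceOp>>&) { /* fill fp16 ops */ }
void add_int8_instances(std::vector<std::unique_ptr<DeviceOp>>&) { /* fill int8 ops */ }

template <typename ADataType>
std::vector<std::unique_ptr<DeviceOp>> get_instances()
{
    std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
    // present only when fp16 support is compiled in
    if constexpr(std::is_same_v<ADataType, fp16_t>)
        add_fp16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_INT8
    // an independent `if constexpr`, not `else if constexpr`: this branch
    // still compiles when the fp16 block above has been preprocessed away
    if constexpr(std::is_same_v<ADataType, std::int8_t>)
        add_int8_instances(op_ptrs);
#endif
    return op_ptrs;
}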
@@ -70,10 +70,23 @@ class TestBatchNormBwdRank4 : public ::testing::Test
     }
 };
-using KernelTypes = ::testing::Types<std::tuple<F16, F32, F32, F32, F16, F32, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, F32, F32, F32, BF16, F32, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F32, F32, F32, F16, F32, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, F32, F32, F32, BF16, F32, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64, F64>
+#endif
+    >;
 TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes);
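The bwd list above and the fwd/infer lists below apply the same guard-per-type idea to the typed-test lists: each std::tuple entry sits in its own #ifdef, and the separating comma lives inside the guard of every entry after the first, so disabling any trailing subset of macros still leaves a well-formed list (the layout does assume the FP16 entry is enabled whenever a later one is, since only the first entry lacks a leading comma). A minimal sketch of how such a list drives a typed test, assuming CK_ENABLE_FP16 is defined and using a shortened two-element tuple with hypothetical alias names; the real fixtures unpack seven tuple parameters for bwd and six for fwd/infer:

#include <gtest/gtest.h>
#include <tuple>

struct F16 {};     // stand-in for ck::half_t
using F32 = float;

template <typename Tuple>
class TestBatchNormSketch : public ::testing::Test
{
protected:
    // unpack tuple parameters into named element types (hypothetical names)
    using XDataType  = std::tuple_element_t<0, Tuple>;
    using DyDataType = std::tuple_element_t<1, Tuple>;
};

using KernelTypesSketch = ::testing::Types<
#ifdef CK_ENABLE_FP16
    std::tuple<F16, F32>
#endif
#ifdef CK_ENABLE_FP32
    , // the comma vanishes together with the entry when the macro is off
    std::tuple<F32, F32>
#endif
    >;

TYPED_TEST_SUITE(TestBatchNormSketch, KernelTypesSketch);

TYPED_TEST(TestBatchNormSketch, InstantiatesPerEnabledTuple)
{
    using X = typename TestFixture::XDataType;
    static_assert(sizeof(X) > 0, "one instantiation per enabled tuple");
}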
@@ -87,10 +87,23 @@ class TestBatchNormFwdRank4 : public ::testing::Test
     }
 };
-using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F16, F32, F16, F16, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, BF16, F32, BF16, BF16, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64>
+#endif
+    >;
 TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes);
@@ -67,10 +67,23 @@ class TestBatchNormInferRank4 : public ::testing::Test
     }
 };
-using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F16, F32, F16, F16, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, BF16, F32, BF16, BF16, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64>
+#endif
+    >;
 TYPED_TEST_SUITE(TestBatchNormInferRank4, KernelTypes);