Commit 10d0a6de authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add more flag checks

parent 1ddc3ec7
......@@ -70,7 +70,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
endif()
endif()
if(DTYPES MATCHES "fp8" OR DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
endif()
......@@ -45,6 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_
PassThrough,
MultiplyAdd>>>&);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
......@@ -70,6 +71,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_m
PassThrough,
PassThrough,
MultiplyAdd>>>&);
#endif
// GEMM + Multiply + Add
template <typename ALayout,
......@@ -131,6 +133,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<D0DataType, float> && is_same_v<D1DataType, float> &&
is_same_v<EDataType, half_t>)
......@@ -150,6 +153,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
#endif
return op_ptrs;
}
......
......@@ -57,6 +57,7 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -96,6 +97,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
template <typename ADataType,
typename BDataType,
......@@ -176,6 +178,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
......@@ -224,6 +227,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
......
add_instance_library(device_gemm_multiply_add_instance
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
set(GEMM_MULTIPLY_ADD_INSTANCES)
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})
......@@ -214,6 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification,
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
// set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<CDataType, f8_t>)
......@@ -226,8 +227,11 @@ bool profile_gemm_splitk_impl(int do_verification,
}
else
{
#endif
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
}
#endif
if(tflops > best_tflops)
{
......
......@@ -59,9 +59,11 @@ int profile_gemm_multiply_add(int argc, char* argv[])
const int StrideD1 = std::stoi(argv[14]);
const int StrideE = std::stoi(argv[15]);
using F8 = ck::f8_t;
using F16 = ck::half_t;
using F32 = float;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
using F8 = ck::f8_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......@@ -132,6 +134,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
{
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 &&
layout == MatrixLayout::MK_KN_MN_MN_MN)
{
......@@ -142,6 +145,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
{
return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
}
#endif
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
......
......@@ -67,7 +67,9 @@ int profile_gemm_splitk(int argc, char* argv[])
using F32 = float;
using F16 = ck::half_t;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
using F8 = ck::f8_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......@@ -146,6 +148,7 @@ int profile_gemm_splitk(int argc, char* argv[])
{
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
}
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
......@@ -178,6 +181,7 @@ int profile_gemm_splitk(int argc, char* argv[])
{
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{});
}
#endif
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment