Commit a69e2f96 authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'miopen_fp8' into mobilenet_fp8

parents 752cb65a 8c5678e0
...@@ -256,6 +256,7 @@ include(CheckLibraryExists) ...@@ -256,6 +256,7 @@ include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION) get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION) get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenSetTensorCastType" "${MIOPEN_LOCATION}" MIOPEN_HAS_BETA_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API) check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning # Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API) check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
...@@ -263,6 +264,10 @@ check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOC ...@@ -263,6 +264,10 @@ check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOC
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API) check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "") set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
set(MIGRAPHX_USE_MIOPEN_BETA_API "${MIOPEN_HAS_BETA_API}" CACHE BOOL "")
if(MIGRAPHX_USE_MIOPEN_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIOPEN_BETA_API -DMIOPEN_DONT_USE_HIP_RUNTIME_HEADERS)
endif()
if(MIGRAPHX_USE_FIND_2_API) if(MIGRAPHX_USE_FIND_2_API)
check_library_exists(MIOpen "miopenSetFindOptionPreallocatedTensor" "${MIOPEN_LOCATION}" HAS_PREALLOCATION_API) check_library_exists(MIOpen "miopenSetFindOptionPreallocatedTensor" "${MIOPEN_LOCATION}" HAS_PREALLOCATION_API)
......
...@@ -143,6 +143,12 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os) ...@@ -143,6 +143,12 @@ inline tensor_descriptor make_tensor(const migraphx::shape& os)
d = miopenInt32; d = miopenInt32;
else if(s.type() == shape::int8_type) else if(s.type() == shape::int8_type)
d = miopenInt8; d = miopenInt8;
else if(s.type() == shape::fp8e4m3fnuz_type)
#ifdef MIOPEN_BETA_API
d = miopenFloat8;
#else
MIGRAPHX_THROW("MIOPEN doesn't have API to support FP8");
#endif
else else
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type"); MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data()); miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
......
...@@ -125,6 +125,8 @@ struct find_add_layernorm ...@@ -125,6 +125,8 @@ struct find_add_layernorm
} }
}; };
#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
struct pre_gemm_softmax_gemm : gemm_softmax_gemm struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{ {
std::string name() const { return "gpu::pre_gemm_softmax_gemm"; } std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
...@@ -201,6 +203,8 @@ struct find_gemm_softmax_gemm ...@@ -201,6 +203,8 @@ struct find_gemm_softmax_gemm
} }
}; };
#endif
} // namespace } // namespace
void prefuse_ops::apply(module_pass_manager& mpm) const void prefuse_ops::apply(module_pass_manager& mpm) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment