Commit cc554d59 authored by wenjh's avatar wenjh
Browse files

Attmp fix build error


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 499dfb3d
......@@ -110,8 +110,9 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
# NVIDIA MathDX include directory (from Python package install location)
if(NOT DEFINED MATHDX_INCLUDE_DIR)
if(USE_CUDA)
# NVIDIA MathDX include directory (from Python package install location)
if(NOT DEFINED MATHDX_INCLUDE_DIR)
execute_process(
COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
OUTPUT_VARIABLE _PIP_SHOW_MATHDX
......@@ -127,9 +128,10 @@ if(NOT DEFINED MATHDX_INCLUDE_DIR)
endif()
set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
endif()
endif()
# Configure Transformer Engine library
......
......@@ -5,7 +5,9 @@
************************************************************************/
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
......
......@@ -5,7 +5,9 @@
************************************************************************/
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
......
......@@ -5,7 +5,9 @@
************************************************************************/
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_runtime.h>
......
......@@ -23,7 +23,9 @@
#endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h>
#ifndef __HIP_PLATFORM_AMD__
#include "nccl.h"
#endif
#ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h>
......
......@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h>
#if CUDA_VERSION > 12080
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment