Commit cc554d59 authored by wenjh's avatar wenjh
Browse files

Attmp fix build error


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 499dfb3d
...@@ -110,8 +110,9 @@ set(CUTLASS_TOOLS_INCLUDE_DIR ...@@ -110,8 +110,9 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
# Python # Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
# NVIDIA MathDX include directory (from Python package install location) if(USE_CUDA)
if(NOT DEFINED MATHDX_INCLUDE_DIR) # NVIDIA MathDX include directory (from Python package install location)
if(NOT DEFINED MATHDX_INCLUDE_DIR)
execute_process( execute_process(
COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
OUTPUT_VARIABLE _PIP_SHOW_MATHDX OUTPUT_VARIABLE _PIP_SHOW_MATHDX
...@@ -127,9 +128,10 @@ if(NOT DEFINED MATHDX_INCLUDE_DIR) ...@@ -127,9 +128,10 @@ if(NOT DEFINED MATHDX_INCLUDE_DIR)
endif() endif()
set(MATHDX_LOCATION "${CMAKE_MATCH_1}") set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include") set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
endif() endif()
if(NOT EXISTS "${MATHDX_INCLUDE_DIR}") if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.") message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
endif()
endif() endif()
# Configure Transformer Engine library # Configure Transformer Engine library
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
************************************************************************/ ************************************************************************/
#include <cuda.h> #include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_pipeline.h> #include <cuda_pipeline.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
************************************************************************/ ************************************************************************/
#include <cuda.h> #include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_pipeline.h> #include <cuda_pipeline.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
************************************************************************/ ************************************************************************/
#include <cuda.h> #include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
#endif // __HIP_PLATFORM_AMD__ #endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h> #include <nvrtc.h>
#ifndef __HIP_PLATFORM_AMD__
#include "nccl.h" #include "nccl.h"
#endif
#ifdef NVTE_WITH_CUBLASMP #ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h> #include <cublasmp.h>
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ #define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#include <cuda.h> #include <cuda.h>
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#endif
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if CUDA_VERSION > 12080 #if CUDA_VERSION > 12080
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment