cmake_minimum_required(VERSION 3.26)

# When building directly using CMake, make sure you run the install step
# (it places the .so files in the correct location).
#
# Example:
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. ..
# cmake --build . --target install
#
# If you want to only build one target, make sure to install it manually:
# cmake --build . --target _C
# cmake --install . --component _C

project(vllm_extensions LANGUAGES CXX)

# Target device backend. CUDA by default; can be overridden by passing
# -DVLLM_TARGET_DEVICE=... (setup.py does this).
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")

message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

# Project helper functions (find_python_from_executable,
# define_gpu_extension_target, override_gpu_arches, ...).
include("${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake")

# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")

# Prevent installation of dependencies (cutlass) by default.
# Components that need nested installs (vllm-flash-attn) flip this back
# explicitly around their own install rules.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

#
# Supported python versions.  These versions are searched in order and the
# first match is selected.  These should be kept in sync with setup.py.
#
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")

# Supported GPU architectures, per GPU language. The variable names matter:
# they are looked up as `${VLLM_GPU_LANG}_SUPPORTED_ARCHS` with
# VLLM_GPU_LANG being "CUDA" or "HIP".

# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")

#
# Supported/expected torch versions for CUDA/ROCm.
#
# Currently, having an incorrect pytorch version results in a warning
# rather than an error.
#
# Note: the CUDA torch version is derived from pyproject.toml and various
# requirements.txt files and should be kept consistent.  The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")

#
# Locate the interpreter named by VLLM_PYTHON_EXECUTABLE. It must match that
# executable exactly and be one of PYTHON_SUPPORTED_VERSIONS. There is no
# fallback search: configuring without VLLM_PYTHON_EXECUTABLE is a hard error.
#
if (NOT VLLM_PYTHON_EXECUTABLE)
  message(FATAL_ERROR
    "Please set VLLM_PYTHON_EXECUTABLE to the path of the desired python version"
    " before running cmake configure.")
endif()
find_python_from_executable(${VLLM_PYTHON_EXECUTABLE} "${PYTHON_SUPPORTED_VERSIONS}")

#
# Add torch's CMake package location (torch.utils.cmake_prefix_path) to
# CMAKE_PREFIX_PATH.
#
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

#
# Import torch cmake configuration.
# Torch also imports CUDA (and partially HIP) languages with some customizations,
# so there is no need to do this explicitly with check_language/enable_language,
# etc.
#
find_package(Torch REQUIRED)

#
# Ensure the 'nvcc' command is available for CUDA builds.
# This must come after find_package(Torch): CUDA_FOUND is not set anywhere
# before this point (presumably it comes from torch's CUDA configuration), so
# testing it earlier meant this check could never fail.
# The hint covers toolkits that are not on PATH; CUDA_TOOLKIT_ROOT_DIR is
# expected to be set by torch's CUDA setup (empty hint is harmless otherwise).
#
find_program(NVCC_EXECUTABLE nvcc HINTS "${CUDA_TOOLKIT_ROOT_DIR}/bin")
if (CUDA_FOUND AND NOT NVCC_EXECUTABLE)
  message(FATAL_ERROR "nvcc not found")
endif()

#
# _core_C extension: built for (almost) every target platform
# (excludes TPU and Neuron).
#
message(STATUS "Enabling core extension.")

set(VLLM_EXT_SRC
  "csrc/core/torch_bindings.cpp")

# NOTE(review): CXX_COMPILE_FLAGS is not set anywhere above this point in this
# file, so it presumably expands to nothing here unless supplied externally —
# confirm this is intended.
define_gpu_extension_target(
  _core_C
  DESTINATION vllm
  LANGUAGE CXX
  SOURCES ${VLLM_EXT_SRC}
  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
  USE_SABI 3
  WITH_SOABI)

#
# Forward the non-CUDA device extensions to external CMake scripts.
# Only "cuda" and "rocm" continue past this block: "cpu" is configured by
# cmake/cpu_extension.cmake, and every other device stops here.
#
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
    NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
  if (VLLM_TARGET_DEVICE STREQUAL "cpu")
    include("${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake")
  endif()
  # The remainder of this file only applies to GPU (CUDA/HIP) builds.
  return()
endif()

#
# Set up GPU language and check the torch version and warn if it isn't
# what is expected.
#
if (NOT HIP_FOUND AND CUDA_FOUND)
  set(VLLM_GPU_LANG "CUDA")

  # CUDA builds expect exactly the pinned torch version.
  if (NOT Torch_VERSION VERSION_EQUAL "${TORCH_SUPPORTED_VERSION_CUDA}")
    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} "
      "expected for CUDA build, saw ${Torch_VERSION} instead.")
  endif()
elseif(HIP_FOUND)
  set(VLLM_GPU_LANG "HIP")

  # Importing torch recognizes and sets up some HIP/ROCm configuration but does
  # not let cmake recognize .hip files. In order to get cmake to understand the
  # .hip extension automatically, HIP must be enabled explicitly.
  enable_language(HIP)

  # ROCm 5.X and 6.X: the requirement is a minimum version (the warning says
  # ">="), so only warn when torch is older than it. An equality check here
  # produced contradictory warnings for newer torch versions.
  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
      Torch_VERSION VERSION_LESS "${TORCH_SUPPORTED_VERSION_ROCM}")
    message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
      "expected for ROCm build, saw ${Torch_VERSION} instead.")
  endif()
else()
  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()

#
# Override the GPU architectures detected by cmake/torch and filter them by
# the supported versions for the current language.
# The final set of arches is stored in `VLLM_GPU_ARCHES`.
#
override_gpu_arches(VLLM_GPU_ARCHES
  ${VLLM_GPU_LANG}
  "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")

#
# Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`.
# The final set of flags is stored in `VLLM_GPU_FLAGS`.
#
get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})

#
# Set nvcc parallelism: for CUDA builds, a non-empty NVCC_THREADS is
# forwarded to nvcc as --threads=<NVCC_THREADS>.
#
if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
  list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()

#
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
# Configure it to place files in <project root>/.deps, in order to play nicely with sccache.
#
include(FetchContent)
get_filename_component(PROJECT_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" ABSOLUTE)
# Set the base directory before creating it: creating it first would create
# whatever FETCHCONTENT_BASE_DIR held previously instead of `.deps`.
set(FETCHCONTENT_BASE_DIR "${PROJECT_ROOT_DIR}/.deps")
file(MAKE_DIRECTORY "${FETCHCONTENT_BASE_DIR}")
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
#
# Define other extension targets
#

#
# _C extension
#

# Sources shared by the CUDA and HIP builds of _C; CUDA-only sources are
# appended below.
set(VLLM_EXT_SRC
  "csrc/cache_kernels.cu"
  "csrc/attention/attention_kernels.cu"
  "csrc/pos_encoding_kernels.cu"
  "csrc/activation_kernels.cu"
  "csrc/layernorm_kernels.cu"
  "csrc/quantization/gptq/q_gemm.cu"
  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
  "csrc/quantization/fp8/common.cu"
  "csrc/cuda_utils_kernels.cu"
  "csrc/moe_align_block_size_kernels.cu"
  "csrc/prepare_inputs/advance_step.cu"
  "csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
  # Use CUTLASS as a header-only library (per the option's description).
  set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

  # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
  # Keep this in sync with the GIT_TAG below.
  set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")

  FetchContent_Declare(
        cutlass
        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
        GIT_TAG v3.5.1
        GIT_PROGRESS TRUE

        # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
        # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
        # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
        GIT_SHALLOW TRUE
  )
  FetchContent_MakeAvailable(cutlass)

  list(APPEND VLLM_EXT_SRC
    "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
    "csrc/mamba/causal_conv1d/causal_conv1d.cu"
    "csrc/quantization/aqlm/gemm_kernels.cu"
    "csrc/quantization/awq/gemm_kernels.cu"
    "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
    "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
    "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
    "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
    "csrc/quantization/gguf/gguf_kernel.cu"
    "csrc/quantization/fp8/fp8_marlin.cu"
    "csrc/custom_all_reduce.cu"
    "csrc/permute_cols.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")

  #
  # The CUTLASS kernels for Hopper require sm90a to be enabled.
  # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
  # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
  #
  # The version is referenced by name (not `${...}`) so an empty value yields
  # a false comparison instead of a malformed if().
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
    set_source_files_properties(
          "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
          PROPERTIES
          COMPILE_FLAGS
          "-gencode arch=compute_90a,code=sm_90a")
  endif()

  #
  # Machete kernels
  #
  # The machete kernels only work on hopper and require CUDA 12.0 or later.
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
    #
    # For the Machete kernels we automatically generate sources for various
    # preselected input type pairs and schedules.
    #
    # The caller's PYTHONPATH must be read with $ENV{PYTHONPATH}: CMake does not
    # expand a bare `$PYTHONPATH` and `cmake -E env` does no shell expansion,
    # so it would be passed through as literal text.
    # NOTE(review): CUTLASS_DIR is not set in this file — confirm it is visible
    # in this scope after FetchContent_MakeAvailable(cutlass).
    execute_process(
      COMMAND ${CMAKE_COMMAND} -E env
        "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH}"
        ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
      RESULT_VARIABLE machete_generation_result
      # stdout and stderr both go to the log file below.
      OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
      ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
    )

    if (NOT machete_generation_result EQUAL 0)
      message(FATAL_ERROR "Machete generation failed."
                          " Result: \"${machete_generation_result}\""
                          "\nCheck the log for details: "
                          "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
    else()
      message(STATUS "Machete generation completed successfully.")
    endif()

    # Add machete generated sources. They are globbed rather than listed
    # because they are expected to be produced by generate.py above, at
    # configure time.
    file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
    list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
    message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")

    set_source_files_properties(
          ${MACHETE_GEN_SOURCES}
          PROPERTIES
          COMPILE_FLAGS
          "-gencode arch=compute_90a,code=sm_90a")
  endif()

  # Add the pytorch binding for machete even when CUDA < 12.0, so that we can
  # raise an error telling the user that this was built with an incompatible
  # CUDA version.
  list(APPEND VLLM_EXT_SRC
    "csrc/quantization/machete/machete_pytorch.cu")
endif()

message(STATUS "Enabling C extension.")
define_gpu_extension_target(
  _C
  DESTINATION vllm
  LANGUAGE ${VLLM_GPU_LANG}
  SOURCES ${VLLM_EXT_SRC}
  COMPILE_FLAGS ${VLLM_GPU_FLAGS}
  ARCHITECTURES ${VLLM_GPU_ARCHES}
  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
  USE_SABI 3
  WITH_SOABI)

# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
# driver API. This causes problems when linking with earlier versions of CUDA.
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

#
# _moe_C extension
#

set(VLLM_MOE_EXT_SRC
  "csrc/moe/torch_bindings.cpp"
  "csrc/moe/topk_softmax_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
  # Marlin MoE kernels are CUDA-only. The .h entries are not compiled by CMake;
  # presumably they are listed so they appear alongside the .cu files in
  # project views.
  list(APPEND VLLM_MOE_EXT_SRC
      "csrc/moe/marlin_kernels/marlin_moe_kernel.h"
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
      "csrc/moe/marlin_moe_ops.cu")
endif()

message(STATUS "Enabling moe extension.")
define_gpu_extension_target(
  _moe_C
  DESTINATION vllm
  LANGUAGE ${VLLM_GPU_LANG}
  SOURCES ${VLLM_MOE_EXT_SRC}
  COMPILE_FLAGS ${VLLM_GPU_FLAGS}
  ARCHITECTURES ${VLLM_GPU_ARCHES}
  USE_SABI 3
  WITH_SOABI)

if(VLLM_GPU_LANG STREQUAL "HIP")
  #
  # _rocm_C extension (ROCm/HIP builds only)
  #
  set(VLLM_ROCM_EXT_SRC
    "csrc/rocm/torch_bindings.cpp"
    "csrc/rocm/attention.cu")

  define_gpu_extension_target(
    _rocm_C
    DESTINATION vllm
    LANGUAGE ${VLLM_GPU_LANG}
    SOURCES ${VLLM_ROCM_EXT_SRC}
    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
    ARCHITECTURES ${VLLM_GPU_ARCHES}
    USE_SABI 3
    WITH_SOABI)
endif()

# vllm-flash-attn currently only supported on CUDA.
# The device check alone is not enough: VLLM_GPU_LANG is chosen from
# HIP_FOUND/CUDA_FOUND, so a "cuda" target device can still be a HIP build.
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" OR NOT VLLM_GPU_LANG STREQUAL "CUDA")
  return()
endif()
#
# Build vLLM flash attention from source
#
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLM's.
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
# If no component is specified, vllm-flash-attn is still installed.

# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# This is to enable local development of vllm-flash-attn within vLLM.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
  set(VLLM_FLASH_ATTN_SRC_DIR "$ENV{VLLM_FLASH_ATTN_SRC_DIR}")
endif()
if(VLLM_FLASH_ATTN_SRC_DIR)
  # Local checkout: use it as-is instead of downloading.
  FetchContent_Declare(vllm-flash-attn SOURCE_DIR "${VLLM_FLASH_ATTN_SRC_DIR}")
else()
  # Pinned to a commit hash, so GIT_SHALLOW must not be enabled here.
  FetchContent_Declare(
          vllm-flash-attn
          GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
          GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
          GIT_PROGRESS TRUE
  )
endif()

# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)

# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)

# Make sure vllm-flash-attn install rules are nested under vllm/
# These install(CODE) snippets bracket the subproject's own install rules.
# That only works because, with cmake_minimum_required(3.26), policy CMP0082 is
# NEW and subdirectory install rules run interleaved in declaration order.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)

# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Restore the install prefix
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)

# Copy over the vllm-flash-attn python files
install(
        DIRECTORY "${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/"
        DESTINATION vllm/vllm_flash_attn
        COMPONENT vllm_flash_attn_c
        FILES_MATCHING PATTERN "*.py"
)

# Nothing after vllm-flash-attn, see comment about macros above