CMakeLists.txt 21.2 KB
Newer Older
1
cmake_minimum_required(VERSION 3.26)

# Building directly with CMake: always run the install step afterwards, since
# that is what places the .so files in the correct location.
#
# Example:
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. ..
# cmake --build . --target install
#
# To build a single target, install its component by hand:
# cmake --build . --target _C
# cmake --install . --component _C

project(vllm_extensions LANGUAGES CXX)

# Backend selection. Defaults to CUDA; setup.py passes -DVLLM_TARGET_DEVICE=...
# to pick something else.
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")

message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

# Project helper commands (presumably where find_python_from_executable,
# define_gpu_extension_target, etc. used below come from).
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

# VLLM_PYTHON_PATH is only consumed on some configurations; touching it here
# keeps CMake from warning about an unused manually-specified variable.
set(ignoreMe "${VLLM_PYTHON_PATH}")

# By default, do not install dependencies (cutlass) alongside vLLM.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

bnellnm's avatar
bnellnm committed
30
31
32
33
#
# Python versions we accept, tried in this order (the first one found wins).
# Keep this list in sync with setup.py.
#
set(PYTHON_SUPPORTED_VERSIONS "3.9;3.10;3.11;3.12")

# NVIDIA GPU architectures vLLM can target.
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")

# AMD GPU architectures vLLM can target.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")

#
# PyTorch versions expected for CUDA and ROCm builds.
#
# A mismatching torch version only produces a warning, not an error.
#
# The CUDA value mirrors pyproject.toml and the various requirements.txt
# files; the ROCm value mirrors Dockerfile.rocm. Keep them consistent.
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")
bnellnm's avatar
bnellnm committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

#
# Try to find python package with an executable that exactly matches
# `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions.
#
if (VLLM_PYTHON_EXECUTABLE)
  find_python_from_executable(${VLLM_PYTHON_EXECUTABLE} "${PYTHON_SUPPORTED_VERSIONS}")
else()
  message(FATAL_ERROR
    "Please set VLLM_PYTHON_EXECUTABLE to the path of the desired python version"
    " before running cmake configure.")
endif()

#
# Update cmake's `CMAKE_PREFIX_PATH` with torch location.
#
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")

# Look for the 'nvcc' command on PATH (result cached in NVCC_EXECUTABLE).
find_program(NVCC_EXECUTABLE nvcc)

#
# Import torch cmake configuration.
# Torch also imports CUDA (and partially HIP) languages with some customizations,
# so there is no need to do this explicitly with check_language/enable_language,
# etc.
#
find_package(Torch REQUIRED)

# Fail early on CUDA builds when no nvcc can be found.
# This check has to run after find_package(Torch). Nothing earlier in this file
# sets CUDA_FOUND; it appears to be populated by torch's CUDA setup. When the
# check sat before the torch import it could therefore never trigger.
# A CUDA compiler that CMake already resolved (CMAKE_CUDA_COMPILER) also
# satisfies the requirement. Builds whose nvcc is located some other way than
# PATH therefore keep working.
if (VLLM_TARGET_DEVICE STREQUAL "cuda" AND CUDA_FOUND
    AND NOT NVCC_EXECUTABLE AND NOT CMAKE_CUDA_COMPILER)
  message(FATAL_ERROR "nvcc not found")
endif()

86
87
88
89
90
91
92
93
#
# Devices other than cuda/rocm are not built by the rest of this file.
# The CPU backend has its own CMake script, included here; any other device
# builds nothing from here. In both cases processing of this file stops.
#
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
    NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
    if (VLLM_TARGET_DEVICE STREQUAL "cpu")
        include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
    endif()
    return()
endif()

bnellnm's avatar
bnellnm committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#
# Pick the GPU language (CUDA or HIP) from what torch found. Warn, without
# failing, if the torch version differs from the one this build expects.
#
if (NOT HIP_FOUND AND CUDA_FOUND)
  set(VLLM_GPU_LANG "CUDA")

  if (NOT Torch_VERSION VERSION_EQUAL "${TORCH_SUPPORTED_VERSION_CUDA}")
    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} "
      "expected for CUDA build, saw ${Torch_VERSION} instead.")
  endif()
elseif(HIP_FOUND)
  set(VLLM_GPU_LANG "HIP")

  # Importing torch recognizes and sets up some HIP/ROCm configuration but does
  # not let cmake recognize .hip files. In order to get cmake to understand the
  # .hip extension automatically, HIP must be enabled explicitly.
  enable_language(HIP)

  # Version check applies to ROCm 5.x and newer.
  # The comparison is an exact match, so the warning names the exact expected
  # version. The old text said ">=", which contradicted the test.
  if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
      NOT Torch_VERSION VERSION_EQUAL "${TORCH_SUPPORTED_VERSION_ROCM}")
    message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
      "expected for ROCm build, saw ${Torch_VERSION} instead.")
  endif()
else()
  message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()

128
129

if(NOT VLLM_GPU_LANG STREQUAL "CUDA")
  #
  # Non-CUDA GPU languages: replace the GPU architectures detected by
  # cmake/torch with the ones supported for the current language.
  # The resulting arch list ends up in `VLLM_GPU_ARCHES`.
  #
  override_gpu_arches(VLLM_GPU_ARCHES
    ${VLLM_GPU_LANG}
    "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
else()
  #
  # CUDA: architectures are chosen per source file to keep compile times down.
  # So we pull the requested architectures out of CMAKE_CUDA_FLAGS (so they are
  # not applied globally) and keep them in CUDA_ARCHS instead.
  #
  clear_cuda_arches(CUDA_ARCH_FLAGS)
  extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
  message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")

  # Some files are compiled for every entry of CUDA_ARCHS, so restrict it to
  # the architectures vLLM supports.
  cuda_archs_loose_intersection(CUDA_ARCHS
    "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
  message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
endif()

bnellnm's avatar
bnellnm committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#
# Ask torch for any extra compiler flags for `VLLM_GPU_LANG`.
# The resulting flag list is stored in `VLLM_GPU_FLAGS`.
#
get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})

#
# Optional nvcc parallelism (CUDA only), controlled by NVCC_THREADS.
#
if(VLLM_GPU_LANG STREQUAL "CUDA" AND NVCC_THREADS)
  list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()

169
170
171

#
# C++ dependencies that are compiled as part of vLLM's build come in through
# FetchContent.
#  - setup.py overrides FETCHCONTENT_BASE_DIR so it plays nicely with sccache.
#  - A dependency that produces build artifacts should set its BINARY_DIR to
#    ${CMAKE_BINARY_DIR}/<dependency>. This keeps different build types from
#    clobbering each other.
#
include(FetchContent)

# Create the base directory up front, then report where it is.
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR})
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
179

bnellnm's avatar
bnellnm committed
180
#
# Extension targets
#

#
# _C extension
#
# Sources shared by every GPU backend. CUDA-only sources are appended to this
# list inside the CUDA-specific section further down.
#
set(VLLM_EXT_SRC
  # KV cache and paged attention
  "csrc/cache_kernels.cu"
  "csrc/attention/paged_attention_v1.cu"
  "csrc/attention/paged_attention_v2.cu"
  # Positional encoding, activation and layernorm kernels
  "csrc/pos_encoding_kernels.cu"
  "csrc/activation_kernels.cu"
  "csrc/layernorm_kernels.cu"
  "csrc/layernorm_quant_kernels.cu"
  # Quantization kernels
  "csrc/quantization/gptq/q_gemm.cu"
  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
  "csrc/quantization/fp8/common.cu"
  "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
  "csrc/quantization/gguf/gguf_kernel.cu"
  # Utilities, input preparation and the torch bindings
  "csrc/cuda_utils_kernels.cu"
  "csrc/prepare_inputs/advance_step.cu"
  "csrc/torch_bindings.cpp")
bnellnm's avatar
bnellnm committed
204
205

if(VLLM_GPU_LANG STREQUAL "CUDA")
  set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

  # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
  # NOTE(review): this label is not derived from the GIT_TAG pinned below;
  # keep the two in sync when bumping CUTLASS.
  set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")

  # Use the specified CUTLASS source directory for compilation if
  # VLLM_CUTLASS_SRC_DIR is provided. The environment variable, when defined,
  # overrides a value passed on the command line.
  if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
    set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR})
  endif()

  if(VLLM_CUTLASS_SRC_DIR)
    # IS_ABSOLUTE tests a path, so pass the variable's value. Testing the bare
    # name would always see the (relative) string "VLLM_CUTLASS_SRC_DIR".
    if(NOT IS_ABSOLUTE "${VLLM_CUTLASS_SRC_DIR}")
      get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE)
    endif()
    message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation")
    FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR})
  else()
    FetchContent_Declare(
        cutlass
        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
        GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
        GIT_PROGRESS TRUE

        # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
        # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
        # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
        GIT_SHALLOW FALSE
    )
  endif()
  FetchContent_MakeAvailable(cutlass)

  list(APPEND VLLM_EXT_SRC
    "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
    "csrc/mamba/causal_conv1d/causal_conv1d.cu"
    "csrc/quantization/aqlm/gemm_kernels.cu"
    "csrc/quantization/awq/gemm_kernels.cu"
    "csrc/custom_all_reduce.cu"
    "csrc/permute_cols.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
    "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
    "csrc/sparse/cutlass/sparse_compressor_entry.cu"
    "csrc/cutlass_extensions/common.cpp")

  set_gencode_flags_for_srcs(
    SRCS "${VLLM_EXT_SRC}"
    CUDA_ARCHS "${CUDA_ARCHS}")

  # Only build Marlin kernels if we are building for at least some compatible archs.
  # Keep building Marlin for 9.0 as there are some group sizes and shapes that
  # are not supported by Machete yet.
  # CUDA_ARCHS is quoted, as at every other call site, so the whole list is
  # passed as one argument. Unquoted, a multi-arch list expands into several
  # arguments and the helper would not receive it as a single target list.
  cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
  if (MARLIN_ARCHS)
    set(MARLIN_SRCS
       "csrc/quantization/fp8/fp8_marlin.cu"
       "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
       "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
       "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
       "csrc/quantization/gptq_marlin/gptq_marlin.cu"
       "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
       "csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
    set_gencode_flags_for_srcs(
      SRCS "${MARLIN_SRCS}"
      CUDA_ARCHS "${MARLIN_ARCHS}")
    list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
    message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
  else()
    message(STATUS "Not building Marlin kernels as no compatible archs found"
                   " in CUDA target architectures")
  endif()

  #
  # The cutlass_scaled_mm, cutlass_scaled_sparse_mm and cutlass_compressor
  # kernels for Hopper (c3x, i.e. CUTLASS 3.x) require CUDA 12.0 or later (and
  # only work on Hopper, 9.0/9.0a for now).
  #
  # The compiler version is compared by variable name (not an unquoted ${}
  # expansion). It uses >= so the test matches the "12.0 or later" requirement
  # stated in the messages.
  #
  cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_3X_ARCHS)
    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
             "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
             "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SRCS}"
      CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
    list(APPEND VLLM_EXT_SRC "${SRCS}")
    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
    message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
  else()
    if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.0 AND SCALED_MM_3X_ARCHS)
      message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
                     "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
                     "later if you intend on running FP8 sparse or quantized models on "
                     "Hopper.")
    else()
      message(STATUS "Not building cutlass_c3x as no compatible archs found "
                     "in CUDA target architectures")
    endif()

    # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
    # build any 3x kernels
    set(SCALED_MM_3X_ARCHS)
  endif()

  #
  # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
  # kernels for the remaining archs that are not already built for 3x.
  #
  cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
    "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
  # subtract out the archs that are already built for 3x
  list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
  if (SCALED_MM_2X_ARCHS)
    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SRCS}"
      CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
    list(APPEND VLLM_EXT_SRC "${SRCS}")
    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
    message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
  else()
    if (SCALED_MM_3X_ARCHS)
      message(STATUS "Not building scaled_mm_c2x as all archs are already built"
                     " for and covered by scaled_mm_c3x")
    else()
      message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
                    "in CUDA target architectures")
    endif()
  endif()

  #
  # Machete kernels
  #
  # The machete kernels only work on hopper and require CUDA 12.0 or later.
  # Only build Machete kernels if we are building for something compatible with sm90a
  cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS)
    #
    # For the Machete kernels we automatically generate sources for various
    # preselected input type pairs and schedules.
    # Generate sources:
    set(MACHETE_GEN_SCRIPT
      ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py)
    file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH)

    message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}")
    message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}")

    # Regenerate when the script changed since the last successful run.
    # Also regenerate when the output directory that the glob below reads from
    # is missing, e.g. after a clean of the source tree with a kept build cache.
    if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH}
        OR NOT "$CACHE{MACHETE_GEN_SCRIPT_HASH}" STREQUAL "${MACHETE_GEN_SCRIPT_HASH}"
        OR NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generated")
      # execute_process does not go through a shell, so a literal `$PYTHONPATH`
      # is never expanded. Forward the caller's PYTHONPATH via $ENV{} instead,
      # and only when it is non-empty.
      set(_machete_pythonpath
        "${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}")
      if (NOT "$ENV{PYTHONPATH}" STREQUAL "")
        string(APPEND _machete_pythonpath ":$ENV{PYTHONPATH}")
      endif()

      # Both stdout and stderr of the generator go to the log file.
      execute_process(
        COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${_machete_pythonpath}"
          ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
        RESULT_VARIABLE machete_generation_result
        OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
        ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
      )
      unset(_machete_pythonpath)

      if (NOT machete_generation_result EQUAL 0)
        message(FATAL_ERROR "Machete generation failed."
                            " Result: \"${machete_generation_result}\""
                            "\nCheck the log for details: "
                            "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
      else()
        set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH}
            CACHE STRING "Last run machete generate script hash" FORCE)
        message(STATUS "Machete generation completed successfully.")
      endif()
    else()
      message(STATUS "Machete generation script has not changed, skipping generation.")
    endif()

    # Add machete generated sources
    file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
    list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})

    # Generated kernels are compiled only for MACHETE_ARCHS.
    set_gencode_flags_for_srcs(
      SRCS "${MACHETE_GEN_SOURCES}"
      CUDA_ARCHS "${MACHETE_ARCHS}")

    list(APPEND VLLM_EXT_SRC
      csrc/quantization/machete/machete_pytorch.cu)

    message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
  else()
    if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.0
        AND MACHETE_ARCHS)
      message(STATUS "Not building Machete kernels as CUDA Compiler version is "
                     "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
                     "later if you intend on running w4a16 quantized models on "
                     "Hopper.")
    else()
      message(STATUS "Not building Machete kernels as no compatible archs "
                     "found in CUDA target architectures")
    endif()
  endif()
# if CUDA endif
endif()

405
message(STATUS "Enabling C extension.")

# The main `_C` extension: every source collected in VLLM_EXT_SRC, compiled as
# VLLM_GPU_LANG with the flags/arches gathered above, installed under `vllm`.
define_gpu_extension_target(
  _C
  DESTINATION vllm
  LANGUAGE ${VLLM_GPU_LANG}
  SOURCES ${VLLM_EXT_SRC}
  COMPILE_FLAGS ${VLLM_GPU_FLAGS}
  ARCHITECTURES ${VLLM_GPU_ARCHES}
  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
  USE_SABI 3
  WITH_SOABI)

# With NVCC >= 12.5, CUTLASS by default reaches the driver API through the
# cudaGetDriverEntryPointByVersion wrapper rather than calling it directly.
# That breaks linking against earlier CUDA versions. Defining this forces
# direct driver calls instead.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

bnellnm's avatar
bnellnm committed
423
424
425
426
427
#
# _moe_C extension
#
# Base MoE sources: the torch bindings plus the align/sum and top-k softmax
# kernels. Marlin MoE sources may be appended below for CUDA.
#
set(VLLM_MOE_EXT_SRC
  "csrc/moe/torch_bindings.cpp"
  "csrc/moe/moe_align_sum_kernels.cu"
  "csrc/moe/topk_softmax_kernels.cu")

# Compile the base MoE sources for every architecture in CUDA_ARCHS.
set_gencode_flags_for_srcs(
  SRCS "${VLLM_MOE_EXT_SRC}"
  CUDA_ARCHS "${CUDA_ARCHS}")

436
if(VLLM_GPU_LANG STREQUAL "CUDA")
  # Marlin MoE kernels are built only for those of the listed 8.x/9.0 archs
  # that are also in CUDA_ARCHS.
  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
  if (NOT MARLIN_MOE_ARCHS)
    message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
                   " in CUDA target architectures")
  else()
    set(MARLIN_MOE_SRC
        # Shared kernel header
        "csrc/moe/marlin_kernels/marlin_moe_kernel.h"
        # Per-variant kernel headers and translation units
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
        # Op entry point
        "csrc/moe/marlin_moe_ops.cu")

    set_gencode_flags_for_srcs(
      SRCS "${MARLIN_MOE_SRC}"
      CUDA_ARCHS "${MARLIN_MOE_ARCHS}")

    list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
    message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
  endif()
endif()

461
message(STATUS "Enabling moe extension.")

# The `_moe_C` extension, built from VLLM_MOE_EXT_SRC with the same language,
# flags and arches as `_C`, installed under `vllm`.
define_gpu_extension_target(
  _moe_C
  DESTINATION vllm
  LANGUAGE ${VLLM_GPU_LANG}
  SOURCES ${VLLM_MOE_EXT_SRC}
  COMPILE_FLAGS ${VLLM_GPU_FLAGS}
  ARCHITECTURES ${VLLM_GPU_ARCHES}
  USE_SABI 3
  WITH_SOABI)

472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
if(VLLM_GPU_LANG STREQUAL "HIP")
  #
  # _rocm_C extension (HIP builds only): ROCm-specific bindings and
  # attention kernel.
  #
  set(VLLM_ROCM_EXT_SRC "csrc/rocm/torch_bindings.cpp;csrc/rocm/attention.cu")

  define_gpu_extension_target(
    _rocm_C
    DESTINATION vllm
    LANGUAGE ${VLLM_GPU_LANG}
    SOURCES ${VLLM_ROCM_EXT_SRC}
    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
    ARCHITECTURES ${VLLM_GPU_ARCHES}
    USE_SABI 3
    WITH_SOABI)
endif()

491
492
493
494
# vllm-flash-attn is only supported on CUDA, so every other target device
# stops here.
if (NOT "${VLLM_TARGET_DEVICE}" STREQUAL "cuda")
  return()
endif()
bnellnm's avatar
bnellnm committed
495

496
497
498
# vLLM flash attention expects VLLM_GPU_ARCHES to hold the target arches in
# CMake syntax (75-real, 89-virtual, etc.). For CUDA we cleared the global
# arches and set gencodes per file instead, so VLLM_GPU_ARCHES has to be
# filled in by hand here.
if(VLLM_GPU_LANG STREQUAL "CUDA" AND CUDA_ARCHS)
  # e.g. "8.9" -> "89" -> "89-real", one entry per element of CUDA_ARCHS.
  set(_vllm_fa_archs ${CUDA_ARCHS})
  list(TRANSFORM _vllm_fa_archs REPLACE "\\." "")
  list(TRANSFORM _vllm_fa_archs APPEND "-real")
  list(APPEND VLLM_GPU_ARCHES ${_vllm_fa_archs})
  unset(_vllm_fa_archs)
endif()

507
508
509
510
511
512
513
514
515
516
#
# Build vLLM flash attention from source
#
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLM's.
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
# If no component is specified, vllm-flash-attn is still installed.

# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# This is to enable local development of vllm-flash-attn within vLLM.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
  set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()

if(VLLM_FLASH_ATTN_SRC_DIR)
  FetchContent_Declare(
          vllm-flash-attn
          SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR}
          # Don't share the vllm-flash-attn build between build types.
          # This matches the download case below; previously a local source
          # directory fell back to FetchContent's default binary dir.
          BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
  )
else()
  FetchContent_Declare(
          vllm-flash-attn
          GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
          GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
          GIT_PROGRESS TRUE
          # Don't share the vllm-flash-attn build between build types
          BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
  )
endif()
538
539
540
541

# Mark this as a parent build so vllm-flash-attn skips redoing its own compile
# flag and architecture initialization.
set(VLLM_PARENT_BUILD ON)

# At install time, create vllm/vllm_flash_attn under the install prefix before
# anything is installed into it.
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)

545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
# Make sure vllm-flash-attn install rules are nested under vllm/
# The install(CODE) snippets below run at install time, in this order:
#  - Temporarily lift CMAKE_INSTALL_LOCAL_ONLY. It is set to TRUE for all
#    components near the top of this file.
#  - Save the install prefix in OLD_CMAKE_INSTALL_PREFIX.
#  - Point CMAKE_INSTALL_PREFIX at <prefix>/vllm/.
# The install rules created by FetchContent_MakeAvailable(vllm-flash-attn)
# therefore land there. Both settings are restored right after.
# NOTE(review): this relies on the fetched project's install rules running
# between these snippets in declaration order (CMake policy CMP0082) -- do
# not reorder these statements.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)

# Fetch the vllm-flash-attn library
# (This also pulls in vllm-flash-attn's own CMake code and install rules.)
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Restore the install prefix
# (and re-enable CMAKE_INSTALL_LOCAL_ONLY, undoing the changes made above).
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)

# Copy over the vllm-flash-attn python files
# Only *.py files are installed, from the fetched source's vllm_flash_attn/
# directory into vllm/vllm_flash_attn. The trailing slash installs the
# directory's contents rather than the directory itself.
install(
        DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
        DESTINATION vllm/vllm_flash_attn
        COMPONENT vllm_flash_attn_c
        FILES_MATCHING PATTERN "*.py"
)

# Nothing after vllm-flash-attn, see comment about macros above
# (vllm-flash-attn's functions/macros may have overwritten vLLM's by now).