Unverified Commit 1cb0cc29 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[FIX] Make `flash_attn` optional (#3269)

parent 99c3cfb8
...@@ -184,6 +184,3 @@ _build/ ...@@ -184,6 +184,3 @@ _build/
# Benchmark dataset # Benchmark dataset
*.json *.json
# Third-party Python packages.
vllm/thirdparty_files/
...@@ -3,7 +3,6 @@ import io ...@@ -3,7 +3,6 @@ import io
import os import os
import re import re
import subprocess import subprocess
import sys
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import List, Set from typing import List, Set
...@@ -15,8 +14,6 @@ import torch.utils.cpp_extension as torch_cpp_ext ...@@ -15,8 +14,6 @@ import torch.utils.cpp_extension as torch_cpp_ext
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
ROOT_DIR = os.path.dirname(__file__) ROOT_DIR = os.path.dirname(__file__)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR = "vllm/thirdparty_files"
# If you are developing the C++ backend of vLLM, consider building vLLM with # If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds. # `python setup.py develop` since it will give you incremental builds.
...@@ -327,46 +324,8 @@ if _is_cuda(): ...@@ -327,46 +324,8 @@ if _is_cuda():
"nvcc": NVCC_FLAGS_PUNICA, "nvcc": NVCC_FLAGS_PUNICA,
}, },
)) ))
elif _is_neuron():
# Download the FlashAttention package. neuronxcc_version = get_neuronxcc_version()
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version = "2.5.6"
install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"-q",
f"--target={install_dir}",
"einops", # Dependency of flash-attn.
f"flash-attn=={flash_attn_version}",
"--no-dependencies", # Required to avoid re-installing torch.
],
env=dict(os.environ, CC="gcc"),
)
# Copy the FlashAttention package into the vLLM package after build.
class build_ext(BuildExtension):
def run(self):
super().run()
target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
self.copy_tree(install_dir, target_dir)
class BinaryDistribution(setuptools.Distribution):
def has_ext_modules(self):
return True
else:
build_ext = BuildExtension
BinaryDistribution = setuptools.Distribution
if _is_neuron():
neuronxcc_version = get_neuronxcc_version()
vllm_extension_sources = [ vllm_extension_sources = [
"csrc/cache_kernels.cu", "csrc/cache_kernels.cu",
...@@ -509,7 +468,6 @@ setuptools.setup( ...@@ -509,7 +468,6 @@ setuptools.setup(
python_requires=">=3.8", python_requires=">=3.8",
install_requires=get_requirements(), install_requires=get_requirements(),
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": build_ext} if not _is_neuron() else {}, cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
distclass=BinaryDistribution,
package_data=package_data, package_data=package_data,
) )
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs""" """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11 from vllm.engine.async_llm_engine import AsyncLLMEngine
def _configure_system(): from vllm.engine.llm_engine import LLMEngine
import os from vllm.engine.ray_utils import initialize_cluster
import sys from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
# Importing flash-attn. from vllm.sampling_params import SamplingParams
thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
"thirdparty_files")
sys.path.insert(0, thirdparty_files)
_configure_system()
# Delete configuration function.
del _configure_system
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402
from vllm.engine.async_llm_engine import AsyncLLMEngine # noqa: E402
from vllm.engine.llm_engine import LLMEngine # noqa: E402
from vllm.engine.ray_utils import initialize_cluster # noqa: E402
from vllm.entrypoints.llm import LLM # noqa: E402
from vllm.outputs import CompletionOutput, RequestOutput # noqa: E402
from vllm.sampling_params import SamplingParams # noqa: E402
__version__ = "0.3.3" __version__ = "0.3.3"
......
"""Attention layer.""" """Attention layer."""
from functools import lru_cache
from typing import List, Optional from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip from vllm.utils import is_hip
logger = init_logger(__name__)
class Attention(nn.Module): class Attention(nn.Module):
"""Attention layer. """Attention layer.
...@@ -30,17 +34,12 @@ class Attention(nn.Module): ...@@ -30,17 +34,12 @@ class Attention(nn.Module):
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and if _use_flash_attn():
torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
# Ampere or later NVIDIA GPUs.
# NOTE(woosuk): FlashAttention does not support FP32.
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
self.backend = FlashAttentionBackend(num_heads, head_size, scale, self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes, num_kv_heads, alibi_slopes,
sliding_window) sliding_window)
else: else:
# Turing and Volta NVIDIA GPUs or AMD GPUs.
# Or FP32 on any GPU.
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
self.backend = XFormersBackend(num_heads, head_size, scale, self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes, num_kv_heads, alibi_slopes,
...@@ -57,3 +56,29 @@ class Attention(nn.Module): ...@@ -57,3 +56,29 @@ class Attention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
return self.backend.forward(query, key, value, key_cache, value_cache, return self.backend.forward(query, key, value, key_cache, value_cache,
input_metadata) input_metadata)
@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return False
if is_hip():
# AMD GPUs.
return False
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("flash_attn is not supported on Turing or older GPUs. "
"Using xformers backend.")
return False
if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
logger.info(
"flash_attn only supports torch.float16 or torch.bfloat16. "
"Using xformers backend.")
return False
logger.info("Using flash_attn backend.")
return True
"""Attention layer with Flash and PagedAttention.""" """Attention layer with Flash and PagedAttention."""
from typing import List, Optional from typing import List, Optional
# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from flash_attn import flash_attn_func from flash_attn import flash_attn_func
import torch import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment