Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
d6770d1f
Unverified
Commit
d6770d1f
authored
Sep 10, 2023
by
Woosuk Kwon
Committed by
GitHub
Sep 10, 2023
Browse files
Update setup.py (#1006)
parent
b9cecc26
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
14 deletions
+35
-14
setup.py
setup.py
+35
-14
No files found.
setup.py
View file @
d6770d1f
...
@@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
...
@@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
if
CUDA_HOME
is
None
:
if
CUDA_HOME
is
None
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Cannot find CUDA_HOME. CUDA must be available to build the package."
)
"Cannot find CUDA_HOME. CUDA must be available to build the package."
)
def
get_nvcc_cuda_version
(
cuda_dir
:
str
)
->
Version
:
def
get_nvcc_cuda_version
(
cuda_dir
:
str
)
->
Version
:
...
@@ -54,7 +54,8 @@ if nvcc_cuda_version < Version("11.0"):
...
@@ -54,7 +54,8 @@ if nvcc_cuda_version < Version("11.0"):
raise
RuntimeError
(
"CUDA 11.0 or higher is required to build the package."
)
raise
RuntimeError
(
"CUDA 11.0 or higher is required to build the package."
)
if
86
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.1"
):
if
86
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.1"
):
raise
RuntimeError
(
raise
RuntimeError
(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
)
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
)
if
89
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.8"
):
if
89
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.8"
):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# However, GPUs with compute capability 8.9 can also run the code generated by
...
@@ -65,7 +66,8 @@ if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
...
@@ -65,7 +66,8 @@ if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
compute_capabilities
.
add
(
80
)
compute_capabilities
.
add
(
80
)
if
90
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.8"
):
if
90
in
compute_capabilities
and
nvcc_cuda_version
<
Version
(
"11.8"
):
raise
RuntimeError
(
raise
RuntimeError
(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
)
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
)
# If no GPU is available, add all supported compute capabilities.
# If no GPU is available, add all supported compute capabilities.
if
not
compute_capabilities
:
if
not
compute_capabilities
:
...
@@ -78,7 +80,9 @@ if not compute_capabilities:
...
@@ -78,7 +80,9 @@ if not compute_capabilities:
# Add target compute capabilities to NVCC flags.
# Add target compute capabilities to NVCC flags.
for
capability
in
compute_capabilities
:
for
capability
in
compute_capabilities
:
NVCC_FLAGS
+=
[
"-gencode"
,
f
"arch=compute_
{
capability
}
,code=sm_
{
capability
}
"
]
NVCC_FLAGS
+=
[
"-gencode"
,
f
"arch=compute_
{
capability
}
,code=sm_
{
capability
}
"
]
# Use NVCC threads to parallelize the build.
# Use NVCC threads to parallelize the build.
if
nvcc_cuda_version
>=
Version
(
"11.2"
):
if
nvcc_cuda_version
>=
Version
(
"11.2"
):
...
@@ -91,7 +95,10 @@ ext_modules = []
...
@@ -91,7 +95,10 @@ ext_modules = []
cache_extension
=
CUDAExtension
(
cache_extension
=
CUDAExtension
(
name
=
"vllm.cache_ops"
,
name
=
"vllm.cache_ops"
,
sources
=
[
"csrc/cache.cpp"
,
"csrc/cache_kernels.cu"
],
sources
=
[
"csrc/cache.cpp"
,
"csrc/cache_kernels.cu"
],
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
},
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
,
},
)
)
ext_modules
.
append
(
cache_extension
)
ext_modules
.
append
(
cache_extension
)
...
@@ -99,7 +106,10 @@ ext_modules.append(cache_extension)
...
@@ -99,7 +106,10 @@ ext_modules.append(cache_extension)
attention_extension
=
CUDAExtension
(
attention_extension
=
CUDAExtension
(
name
=
"vllm.attention_ops"
,
name
=
"vllm.attention_ops"
,
sources
=
[
"csrc/attention.cpp"
,
"csrc/attention/attention_kernels.cu"
],
sources
=
[
"csrc/attention.cpp"
,
"csrc/attention/attention_kernels.cu"
],
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
},
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
,
},
)
)
ext_modules
.
append
(
attention_extension
)
ext_modules
.
append
(
attention_extension
)
...
@@ -107,7 +117,10 @@ ext_modules.append(attention_extension)
...
@@ -107,7 +117,10 @@ ext_modules.append(attention_extension)
positional_encoding_extension
=
CUDAExtension
(
positional_encoding_extension
=
CUDAExtension
(
name
=
"vllm.pos_encoding_ops"
,
name
=
"vllm.pos_encoding_ops"
,
sources
=
[
"csrc/pos_encoding.cpp"
,
"csrc/pos_encoding_kernels.cu"
],
sources
=
[
"csrc/pos_encoding.cpp"
,
"csrc/pos_encoding_kernels.cu"
],
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
},
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
,
},
)
)
ext_modules
.
append
(
positional_encoding_extension
)
ext_modules
.
append
(
positional_encoding_extension
)
...
@@ -115,7 +128,10 @@ ext_modules.append(positional_encoding_extension)
...
@@ -115,7 +128,10 @@ ext_modules.append(positional_encoding_extension)
layernorm_extension
=
CUDAExtension
(
layernorm_extension
=
CUDAExtension
(
name
=
"vllm.layernorm_ops"
,
name
=
"vllm.layernorm_ops"
,
sources
=
[
"csrc/layernorm.cpp"
,
"csrc/layernorm_kernels.cu"
],
sources
=
[
"csrc/layernorm.cpp"
,
"csrc/layernorm_kernels.cu"
],
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
},
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
,
},
)
)
ext_modules
.
append
(
layernorm_extension
)
ext_modules
.
append
(
layernorm_extension
)
...
@@ -123,7 +139,10 @@ ext_modules.append(layernorm_extension)
...
@@ -123,7 +139,10 @@ ext_modules.append(layernorm_extension)
activation_extension
=
CUDAExtension
(
activation_extension
=
CUDAExtension
(
name
=
"vllm.activation_ops"
,
name
=
"vllm.activation_ops"
,
sources
=
[
"csrc/activation.cpp"
,
"csrc/activation_kernels.cu"
],
sources
=
[
"csrc/activation.cpp"
,
"csrc/activation_kernels.cu"
],
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
},
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
,
},
)
)
ext_modules
.
append
(
activation_extension
)
ext_modules
.
append
(
activation_extension
)
...
@@ -138,8 +157,8 @@ def find_version(filepath: str):
...
@@ -138,8 +157,8 @@ def find_version(filepath: str):
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
"""
"""
with
open
(
filepath
)
as
fp
:
with
open
(
filepath
)
as
fp
:
version_match
=
re
.
search
(
version_match
=
re
.
search
(
r
"^__version__ = ['\"]([^'\"]*)['\"]"
,
r
"^__version__ = ['\"]([^'\"]*)['\"]"
,
fp
.
read
(),
re
.
M
)
fp
.
read
(),
re
.
M
)
if
version_match
:
if
version_match
:
return
version_match
.
group
(
1
)
return
version_match
.
group
(
1
)
raise
RuntimeError
(
"Unable to find version string."
)
raise
RuntimeError
(
"Unable to find version string."
)
...
@@ -162,7 +181,8 @@ setuptools.setup(
...
@@ -162,7 +181,8 @@ setuptools.setup(
version
=
find_version
(
get_path
(
"vllm"
,
"__init__.py"
)),
version
=
find_version
(
get_path
(
"vllm"
,
"__init__.py"
)),
author
=
"vLLM Team"
,
author
=
"vLLM Team"
,
license
=
"Apache 2.0"
,
license
=
"Apache 2.0"
,
description
=
"A high-throughput and memory-efficient inference and serving engine for LLMs"
,
description
=
(
"A high-throughput and memory-efficient inference and "
"serving engine for LLMs"
),
long_description
=
read_readme
(),
long_description
=
read_readme
(),
long_description_content_type
=
"text/markdown"
,
long_description_content_type
=
"text/markdown"
,
url
=
"https://github.com/vllm-project/vllm"
,
url
=
"https://github.com/vllm-project/vllm"
,
...
@@ -174,11 +194,12 @@ setuptools.setup(
...
@@ -174,11 +194,12 @@ setuptools.setup(
"Programming Language :: Python :: 3.8"
,
"Programming Language :: Python :: 3.8"
,
"Programming Language :: Python :: 3.9"
,
"Programming Language :: Python :: 3.9"
,
"Programming Language :: Python :: 3.10"
,
"Programming Language :: Python :: 3.10"
,
"Programming Language :: Python :: 3.11"
,
"License :: OSI Approved :: Apache Software License"
,
"License :: OSI Approved :: Apache Software License"
,
"Topic :: Scientific/Engineering :: Artificial Intelligence"
,
"Topic :: Scientific/Engineering :: Artificial Intelligence"
,
],
],
packages
=
setuptools
.
find_packages
(
packages
=
setuptools
.
find_packages
(
exclude
=
(
"benchmarks"
,
"csrc"
,
"docs"
,
exclude
=
(
"assets"
,
"benchmarks"
,
"csrc"
,
"docs"
,
"examples"
,
"tests"
)),
"examples"
,
"tests"
)),
python_requires
=
">=3.8"
,
python_requires
=
">=3.8"
,
install_requires
=
get_requirements
(),
install_requires
=
get_requirements
(),
ext_modules
=
ext_modules
,
ext_modules
=
ext_modules
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment