Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1a8bfd92
Unverified
Commit
1a8bfd92
authored
Jun 12, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 12, 2024
Browse files
[Hardware] Initial TPU integration (#5292)
parent
847cdcca
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
599 additions
and
28 deletions
+599
-28
Dockerfile.tpu
Dockerfile.tpu
+19
-0
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+1
-1
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+1
-1
docs/source/getting_started/tpu-installation.rst
docs/source/getting_started/tpu-installation.rst
+75
-0
docs/source/index.rst
docs/source/index.rst
+2
-1
requirements-tpu.txt
requirements-tpu.txt
+7
-0
setup.py
setup.py
+17
-5
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+232
-0
vllm/attention/selector.py
vllm/attention/selector.py
+11
-2
vllm/config.py
vllm/config.py
+5
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+3
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+3
-0
vllm/envs.py
vllm/envs.py
+6
-0
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+101
-0
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+3
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+74
-3
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+21
-6
vllm/utils.py
vllm/utils.py
+14
-0
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+3
-6
No files found.
Dockerfile.tpu
0 → 100644
View file @
1a8bfd92
ARG NIGHTLY_DATE="20240601"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE
WORKDIR /workspace
COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu"
# Install aiohttp separately to avoid build errors.
RUN pip install aiohttp
# Install the TPU and Pallas dependencies.
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# Build vLLM.
RUN cd /workspace/vllm && python setup.py develop
CMD ["/bin/bash"]
benchmarks/benchmark_latency.py
View file @
1a8bfd92
...
@@ -189,7 +189,7 @@ if __name__ == '__main__':
...
@@ -189,7 +189,7 @@ if __name__ == '__main__':
"--device"
,
"--device"
,
type
=
str
,
type
=
str
,
default
=
"cuda"
,
default
=
"cuda"
,
choices
=
[
"cuda"
,
"cpu"
],
choices
=
[
"cuda"
,
"cpu"
,
"tpu"
],
help
=
'device type for vLLM execution, supporting CUDA and CPU.'
)
help
=
'device type for vLLM execution, supporting CUDA and CPU.'
)
parser
.
add_argument
(
'--block-size'
,
parser
.
add_argument
(
'--block-size'
,
type
=
int
,
type
=
int
,
...
...
benchmarks/benchmark_throughput.py
View file @
1a8bfd92
...
@@ -346,7 +346,7 @@ if __name__ == "__main__":
...
@@ -346,7 +346,7 @@ if __name__ == "__main__":
"--device"
,
"--device"
,
type
=
str
,
type
=
str
,
default
=
"cuda"
,
default
=
"cuda"
,
choices
=
[
"cuda"
,
"cpu"
],
choices
=
[
"cuda"
,
"cpu"
,
"tpu"
],
help
=
'device type for vLLM execution, supporting CUDA and CPU.'
)
help
=
'device type for vLLM execution, supporting CUDA and CPU.'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--enable-prefix-caching"
,
"--enable-prefix-caching"
,
...
...
docs/source/getting_started/tpu-installation.rst
0 → 100644
View file @
1a8bfd92
.. _installation_tpu:
Installation with TPU
=====================
vLLM supports Google Cloud TPUs using PyTorch XLA.
Requirements
------------
* Google Cloud TPU VM (single host)
* TPU versions: v5e, v5p, v4
* Python: 3.10
Installation options:
1. :ref:`Build a docker image with Dockerfile <build_docker_tpu>`.
2. :ref:`Build from source <build_from_source_tpu>`.
.. _build_docker_tpu:
Build a docker image with :code:`Dockerfile.tpu`
------------------------------------------------
`Dockerfile.tpu <https://github.com/vllm-project/vllm/blob/main/Dockerfile.tpu>`_ is provided to build a docker image with TPU support.
.. code-block:: console
$ docker build -f Dockerfile.tpu -t vllm-tpu .
You can run the docker image with the following command:
.. code-block:: console
$ # Make sure to add `--privileged --net host --shm-size=16G`.
$ docker run --privileged --net host --shm-size=16G -it vllm-tpu
.. _build_from_source_tpu:
Build from source
-----------------
You can also build and install the TPU backend from source.
First, install the dependencies:
.. code-block:: console
$ # (Recommended) Create a new conda environment.
$ conda create -n myenv python=3.10 -y
$ conda activate myenv
$ # Clean up the existing torch and torch-xla packages.
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240601"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ # Install JAX and Pallas.
$ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
$ # Install other build dependencies.
$ pip install packaging aiohttp
Next, build vLLM from source. This will only take a few seconds:
.. code-block:: console
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop
docs/source/index.rst
View file @
1a8bfd92
...
@@ -63,8 +63,9 @@ Documentation
...
@@ -63,8 +63,9 @@ Documentation
getting_started/installation
getting_started/installation
getting_started/amd-installation
getting_started/amd-installation
getting_started/neuron-installation
getting_started/cpu-installation
getting_started/cpu-installation
getting_started/neuron-installation
getting_started/tpu-installation
getting_started/quickstart
getting_started/quickstart
getting_started/debugging
getting_started/debugging
getting_started/examples/examples_index
getting_started/examples/examples_index
...
...
requirements-tpu.txt
0 → 100644
View file @
1a8bfd92
# Common dependencies
-r requirements-common.txt
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
triton # To avoid import errors
setup.py
View file @
1a8bfd92
...
@@ -206,9 +206,9 @@ class cmake_build_ext(build_ext):
...
@@ -206,9 +206,9 @@ class cmake_build_ext(build_ext):
def
_is_cuda
()
->
bool
:
def
_is_cuda
()
->
bool
:
return
VLLM_TARGET_DEVICE
==
"cuda"
\
has_cuda
=
torch
.
version
.
cuda
is
not
None
and
torch
.
version
.
cuda
is
not
None
\
return
(
VLLM_TARGET_DEVICE
==
"cuda"
and
has_cuda
and
not
_is_neuron
()
and
not
(
_is_neuron
()
or
_is_tpu
()))
def
_is_hip
()
->
bool
:
def
_is_hip
()
->
bool
:
...
@@ -225,10 +225,18 @@ def _is_neuron() -> bool:
...
@@ -225,10 +225,18 @@ def _is_neuron() -> bool:
return
torch_neuronx_installed
or
VLLM_TARGET_DEVICE
==
"neuron"
return
torch_neuronx_installed
or
VLLM_TARGET_DEVICE
==
"neuron"
def
_is_tpu
()
->
bool
:
return
VLLM_TARGET_DEVICE
==
"tpu"
def
_is_cpu
()
->
bool
:
def
_is_cpu
()
->
bool
:
return
VLLM_TARGET_DEVICE
==
"cpu"
return
VLLM_TARGET_DEVICE
==
"cpu"
def
_build_custom_ops
()
->
bool
:
return
_is_cuda
()
or
_is_hip
()
or
_is_cpu
()
def
_install_punica
()
->
bool
:
def
_install_punica
()
->
bool
:
return
envs
.
VLLM_INSTALL_PUNICA_KERNELS
return
envs
.
VLLM_INSTALL_PUNICA_KERNELS
...
@@ -325,6 +333,8 @@ def get_vllm_version() -> str:
...
@@ -325,6 +333,8 @@ def get_vllm_version() -> str:
if
neuron_version
!=
MAIN_CUDA_VERSION
:
if
neuron_version
!=
MAIN_CUDA_VERSION
:
neuron_version_str
=
neuron_version
.
replace
(
"."
,
""
)[:
3
]
neuron_version_str
=
neuron_version
.
replace
(
"."
,
""
)[:
3
]
version
+=
f
"+neuron
{
neuron_version_str
}
"
version
+=
f
"+neuron
{
neuron_version_str
}
"
elif
_is_tpu
():
version
+=
"+tpu"
elif
_is_cpu
():
elif
_is_cpu
():
version
+=
"+cpu"
version
+=
"+cpu"
else
:
else
:
...
@@ -372,6 +382,8 @@ def get_requirements() -> List[str]:
...
@@ -372,6 +382,8 @@ def get_requirements() -> List[str]:
requirements
=
_read_requirements
(
"requirements-rocm.txt"
)
requirements
=
_read_requirements
(
"requirements-rocm.txt"
)
elif
_is_neuron
():
elif
_is_neuron
():
requirements
=
_read_requirements
(
"requirements-neuron.txt"
)
requirements
=
_read_requirements
(
"requirements-neuron.txt"
)
elif
_is_tpu
():
requirements
=
_read_requirements
(
"requirements-tpu.txt"
)
elif
_is_cpu
():
elif
_is_cpu
():
requirements
=
_read_requirements
(
"requirements-cpu.txt"
)
requirements
=
_read_requirements
(
"requirements-cpu.txt"
)
else
:
else
:
...
@@ -385,7 +397,7 @@ ext_modules = []
...
@@ -385,7 +397,7 @@ ext_modules = []
if
_is_cuda
()
or
_is_hip
():
if
_is_cuda
()
or
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
if
not
_is_neuron
():
if
_build_custom_ops
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._C"
))
if
_install_punica
():
if
_install_punica
():
...
@@ -428,6 +440,6 @@ setup(
...
@@ -428,6 +440,6 @@ setup(
extras_require
=
{
extras_require
=
{
"tensorizer"
:
[
"tensorizer>=2.9.0"
],
"tensorizer"
:
[
"tensorizer>=2.9.0"
],
},
},
cmdclass
=
{
"build_ext"
:
cmake_build_ext
}
if
not
_is_neuron
()
else
{},
cmdclass
=
{
"build_ext"
:
cmake_build_ext
}
if
_build_custom_ops
()
else
{},
package_data
=
package_data
,
package_data
=
package_data
,
)
)
vllm/attention/backends/pallas.py
0 → 100644
View file @
1a8bfd92
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
import
torch_xla.experimental.dynamo_set_buffer_donor
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
class
PallasAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_impl_cls
()
->
Type
[
"PallasAttentionBackendImpl"
]:
return
PallasAttentionBackendImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"PallasMetadata"
:
return
PallasMetadata
(
*
args
,
**
kwargs
)
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
num_kv_heads
,
num_blocks
,
block_size
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
],
)
->
None
:
raise
NotImplementedError
(
"swap_blocks is not implemented."
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
# TODO(woosuk): Implement this.
raise
NotImplementedError
(
"copy_blocks is not implemented."
)
@
dataclass
class
PallasMetadata
(
AttentionMetadata
):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
]
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
assert
self
.
num_decode_tokens
==
0
assert
self
.
block_tables
is
None
assert
self
.
context_lens
is
None
return
self
@
property
def
decode_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
assert
self
.
num_prefills
==
0
assert
self
.
num_prefill_tokens
==
0
assert
self
.
block_tables
is
not
None
assert
self
.
context_lens
is
not
None
return
self
class
PallasAttentionBackendImpl
(
AttentionImpl
):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
head_size
%
128
!=
0
:
raise
NotImplementedError
(
"Head size must be a multiple of 128."
)
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
"Alibi slopes is not supported."
)
if
sliding_window
is
not
None
:
raise
NotImplementedError
(
"Sliding window is not supported."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
if
torch_xla
.
tpu
.
version
()
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
self
.
megacore_mode
=
None
tpu_type
=
torch_xla
.
tpu
.
get_tpu_env
()[
"TYPE"
].
lower
()
if
not
tpu_type
.
endswith
(
"lite"
):
if
self
.
num_kv_heads
%
2
==
0
:
self
.
megacore_mode
=
"kv_head"
else
:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self
.
megacore_mode
=
"batch"
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]],
attn_metadata
:
PallasMetadata
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert
kv_scale
==
1.0
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
[
0
]
is
not
None
:
slot_mapping
=
attn_metadata
.
slot_mapping
key_cache
,
value_cache
=
kv_cache
write_to_kv_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
query
=
query
*
self
.
scale
if
attn_metadata
.
num_prefills
>
0
:
assert
seq_len
%
16
==
0
,
(
"Pallas FlashAttention kernel requires seq_len to be a "
f
"multiple of 16 but got
{
seq_len
}
"
)
# Handle GQA/MQA.
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
2
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
2
)
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
# FlashAttention requires [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output
=
torch
.
ops
.
xla
.
flash_attention
(
query
.
permute
(
0
,
2
,
1
,
3
),
key
.
permute
(
0
,
2
,
1
,
3
),
value
.
permute
(
0
,
2
,
1
,
3
),
True
,
)
output
=
output
.
permute
(
0
,
2
,
1
,
3
)
else
:
# Decoding run.
assert
kv_cache
is
not
None
pages_per_compute_block
=
16
# TODO(woosuk): Tune this value.
if
self
.
megacore_mode
==
"batch"
and
batch_size
%
2
!=
0
:
megacore_mode
=
None
else
:
megacore_mode
=
self
.
megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if
megacore_mode
is
not
None
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
.
squeeze
(
dim
=
1
),
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
pages_per_compute_block
,
megacore_mode
=
megacore_mode
,
)
else
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
.
squeeze
(
dim
=
1
),
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
pages_per_compute_block
,
)
# Reshape the output tensor.
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
def
write_to_kv_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
value_cache
,
True
)
key
=
key
.
flatten
(
0
,
2
)
value
=
value
.
flatten
(
0
,
2
)
key_cache
=
key_cache
.
flatten
(
0
,
2
)
value_cache
=
value_cache
.
flatten
(
0
,
2
)
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
vllm/attention/selector.py
View file @
1a8bfd92
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_cpu
,
is_hip
from
vllm.utils
import
is_cpu
,
is_hip
,
is_tpu
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -18,6 +18,7 @@ class _Backend(enum.Enum):
...
@@ -18,6 +18,7 @@ class _Backend(enum.Enum):
ROCM_FLASH
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
...
@@ -66,6 +67,10 @@ def get_attn_backend(
...
@@ -66,6 +67,10 @@ def get_attn_backend(
"Please make sure --enforce-eager is set."
)
"Please make sure --enforce-eager is set."
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
return
FlashInferBackend
elif
backend
==
_Backend
.
PALLAS
:
logger
.
info
(
"Using Pallas backend."
)
from
vllm.attention.backends.pallas
import
PallasAttentionBackend
return
PallasAttentionBackend
else
:
else
:
raise
ValueError
(
"Invalid attention backend."
)
raise
ValueError
(
"Invalid attention backend."
)
...
@@ -80,7 +85,6 @@ def which_attn_to_use(
...
@@ -80,7 +85,6 @@ def which_attn_to_use(
block_size
:
int
,
block_size
:
int
,
)
->
_Backend
:
)
->
_Backend
:
"""Returns which flash attention backend to use."""
"""Returns which flash attention backend to use."""
# Default case.
# Default case.
selected_backend
=
_Backend
.
FLASH_ATTN
selected_backend
=
_Backend
.
FLASH_ATTN
...
@@ -100,6 +104,11 @@ def which_attn_to_use(
...
@@ -100,6 +104,11 @@ def which_attn_to_use(
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
return
_Backend
.
TORCH_SDPA
return
_Backend
.
TORCH_SDPA
if
is_tpu
():
if
selected_backend
!=
_Backend
.
PALLAS
:
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
return
_Backend
.
PALLAS
if
is_hip
():
if
is_hip
():
# AMD GPUs.
# AMD GPUs.
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
...
...
vllm/config.py
View file @
1a8bfd92
...
@@ -11,7 +11,7 @@ from vllm.logger import init_logger
...
@@ -11,7 +11,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
from
vllm.utils
import
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
,
is_tpu
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
...
@@ -748,6 +748,8 @@ class DeviceConfig:
...
@@ -748,6 +748,8 @@ class DeviceConfig:
# Automated device type detection
# Automated device type detection
if
is_neuron
():
if
is_neuron
():
self
.
device_type
=
"neuron"
self
.
device_type
=
"neuron"
elif
is_tpu
():
self
.
device_type
=
"tpu"
elif
is_cpu
():
elif
is_cpu
():
self
.
device_type
=
"cpu"
self
.
device_type
=
"cpu"
else
:
else
:
...
@@ -761,6 +763,8 @@ class DeviceConfig:
...
@@ -761,6 +763,8 @@ class DeviceConfig:
# Some device types require processing inputs on CPU
# Some device types require processing inputs on CPU
if
self
.
device_type
in
[
"neuron"
]:
if
self
.
device_type
in
[
"neuron"
]:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
device
=
torch
.
device
(
"cpu"
)
elif
self
.
device_type
in
[
"tpu"
]:
self
.
device
=
None
else
:
else
:
# Set device with device type
# Set device with device type
self
.
device
=
torch
.
device
(
self
.
device_type
)
self
.
device
=
torch
.
device
(
self
.
device_type
)
...
...
vllm/engine/arg_utils.py
View file @
1a8bfd92
...
@@ -504,7 +504,7 @@ class EngineArgs:
...
@@ -504,7 +504,7 @@ class EngineArgs:
parser
.
add_argument
(
"--device"
,
parser
.
add_argument
(
"--device"
,
type
=
str
,
type
=
str
,
default
=
EngineArgs
.
device
,
default
=
EngineArgs
.
device
,
choices
=
[
"auto"
,
"cuda"
,
"neuron"
,
"cpu"
],
choices
=
[
"auto"
,
"cuda"
,
"neuron"
,
"cpu"
,
"tpu"
],
help
=
'Device type for vLLM execution.'
)
help
=
'Device type for vLLM execution.'
)
# Related to Vision-language models such as llava
# Related to Vision-language models such as llava
...
...
vllm/engine/async_llm_engine.py
View file @
1a8bfd92
...
@@ -375,6 +375,9 @@ class AsyncLLMEngine:
...
@@ -375,6 +375,9 @@ class AsyncLLMEngine:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
from
vllm.executor.tpu_executor
import
TPUExecutorAsync
executor_class
=
TPUExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
assert
distributed_executor_backend
is
None
,
(
assert
distributed_executor_backend
is
None
,
(
"Distributed execution is not supported with the CPU backend."
)
"Distributed execution is not supported with the CPU backend."
)
...
...
vllm/engine/llm_engine.py
View file @
1a8bfd92
...
@@ -341,6 +341,9 @@ class LLMEngine:
...
@@ -341,6 +341,9 @@ class LLMEngine:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutor
from
vllm.executor.neuron_executor
import
NeuronExecutor
executor_class
=
NeuronExecutor
executor_class
=
NeuronExecutor
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
from
vllm.executor.tpu_executor
import
TPUExecutor
executor_class
=
TPUExecutor
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
from
vllm.executor.cpu_executor
import
CPUExecutor
from
vllm.executor.cpu_executor
import
CPUExecutor
executor_class
=
CPUExecutor
executor_class
=
CPUExecutor
...
...
vllm/envs.py
View file @
1a8bfd92
...
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
...
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_XLA_CACHE_PATH
:
str
=
"~/.vllm/xla_cache/"
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"spawn"
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"spawn"
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
...
@@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Default is 5 seconds
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT"
:
"VLLM_IMAGE_FETCH_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_XLA_CACHE_PATH"
,
"~/.vllm/xla_cache/"
),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/executor/tpu_executor.py
0 → 100644
View file @
1a8bfd92
from
typing
import
List
,
Set
,
Tuple
import
torch
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
logger
=
init_logger
(
__name__
)
class
TPUExecutor
(
ExecutorBase
):
def
_init_executor
(
self
)
->
None
:
assert
not
self
.
scheduler_config
.
chunked_prefill_enabled
,
(
"Chunked prefill is not yet supported for TPU backend"
)
assert
not
self
.
speculative_config
,
(
"Speculative decoding is not yet supported for TPU backend"
)
if
self
.
model_config
.
dtype
in
(
torch
.
float16
,
torch
.
float32
):
logger
.
warning
(
"The TPU backend currently does not support %s. "
"Using bfloat16 instead."
,
self
.
model_config
.
dtype
)
self
.
model_config
.
dtype
=
torch
.
bfloat16
# Instantiate the worker and load the model to the device.
self
.
_init_worker
()
def
_init_worker
(
self
):
from
vllm.worker.tpu_worker
import
TPUWorker
assert
self
.
parallel_config
.
world_size
==
1
,
(
"TPUExecutor currently only supports a single TPU chip."
)
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
self
.
driver_worker
=
TPUWorker
(
self
.
model_config
,
self
.
parallel_config
,
self
.
scheduler_config
,
self
.
device_config
,
self
.
cache_config
,
self
.
load_config
,
self
.
vision_language_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
)
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
load_model
()
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
,
)
->
None
:
"""Initialize the KV cache by invoking the underlying worker."""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger
.
info
(
"# TPU blocks: %d, # CPU blocks: %d"
,
num_gpu_blocks
,
num_cpu_blocks
)
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return
self
.
driver_worker
.
determine_num_available_blocks
()
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
raise
NotImplementedError
(
"LoRA is not implemented for TPU backend."
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
(
"LoRA is not implemented for TPU backend."
)
def
list_loras
(
self
)
->
Set
[
int
]:
raise
NotImplementedError
(
"LoRA is not implemented for TPU backend."
)
def
check_health
(
self
)
->
None
:
# TPUExecutor will always be healthy as long as it's running.
return
class
TPUExecutorAsync
(
TPUExecutor
,
ExecutorAsyncBase
):
async
def
execute_model_async
(
self
,
sexecute_model_req
:
ExecuteModelRequest
,
)
->
SamplerOutput
:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
sexecute_model_req
)
return
output
vllm/model_executor/custom_op.py
View file @
1a8bfd92
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.utils
import
is_cpu
,
is_hip
from
vllm.utils
import
is_cpu
,
is_hip
,
is_tpu
class
CustomOp
(
nn
.
Module
):
class
CustomOp
(
nn
.
Module
):
...
@@ -56,5 +56,7 @@ class CustomOp(nn.Module):
...
@@ -56,5 +56,7 @@ class CustomOp(nn.Module):
return
self
.
forward_hip
return
self
.
forward_hip
elif
is_cpu
():
elif
is_cpu
():
return
self
.
forward_cpu
return
self
.
forward_cpu
elif
is_tpu
():
return
self
.
forward_tpu
else
:
else
:
return
self
.
forward_cuda
return
self
.
forward_cuda
vllm/model_executor/layers/rotary_embedding.py
View file @
1a8bfd92
...
@@ -28,6 +28,7 @@ import torch
...
@@ -28,6 +28,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.utils
import
is_tpu
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
...
@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return
x
.
flatten
(
-
2
)
return
x
.
flatten
(
-
2
)
def
_apply_rotary_emb
(
x
:
torch
.
Tensor
,
freqs_cis
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
x_
=
torch
.
view_as_complex
(
torch
.
stack
(
torch
.
chunk
(
x
.
transpose
(
1
,
2
).
float
(),
2
,
dim
=-
1
),
dim
=-
1
))
x_out
=
torch
.
view_as_real
(
x_
*
freqs_cis
).
type_as
(
x
)
x_out
=
torch
.
cat
(
torch
.
chunk
(
x_out
,
2
,
dim
=-
1
),
dim
=-
2
)
x_out
=
x_out
.
reshape
(
x_out
.
shape
[
0
],
x_out
.
shape
[
1
],
x_out
.
shape
[
2
],
-
1
).
transpose
(
1
,
2
)
return
x_out
class
RotaryEmbedding
(
CustomOp
):
class
RotaryEmbedding
(
CustomOp
):
"""Original rotary positional embedding."""
"""Original rotary positional embedding."""
...
@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
...
@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
self
.
dtype
=
dtype
self
.
dtype
=
dtype
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
dtype
)
self
.
use_native2
=
is_tpu
()
and
is_neox_style
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
if
not
self
.
use_native2
:
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
else
:
cos
,
sin
=
cache
.
chunk
(
2
,
dim
=-
1
)
freqs_cis
=
cos
+
1j
*
sin
self
.
register_buffer
(
"freqs_cis"
,
freqs_cis
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
"""Compute the inverse frequency."""
...
@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
...
@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
"""A PyTorch-native implementation equivalent to forward().
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
...
@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
...
@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
key
=
key
.
flatten
(
-
2
)
key
=
key
.
flatten
(
-
2
)
return
query
,
key
return
query
,
key
def
forward_native2
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Another PyTorch-native implementation of forward().
This method might perform better than `forward_native()` when compiled.
"""
if
positions
.
dim
()
==
1
:
batch_size
=
1
seq_len
=
positions
.
shape
[
0
]
else
:
batch_size
,
seq_len
=
positions
.
shape
if
offsets
is
not
None
:
positions
=
positions
+
offsets
freqs_cis
=
self
.
freqs_cis
.
index_select
(
0
,
positions
.
flatten
())
freqs_cis
=
freqs_cis
.
view
(
batch_size
,
1
,
seq_len
,
-
1
)
query_shape
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
freqs_cis
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
batch_size
,
seq_len
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
freqs_cis
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -161,6 +221,17 @@ class RotaryEmbedding(CustomOp):
...
@@ -161,6 +221,17 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def
forward_tpu
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_fn
=
(
self
.
forward_native2
if
self
.
use_native2
else
self
.
forward_native
)
return
forward_fn
(
positions
,
query
,
key
,
offsets
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
...
...
vllm/model_executor/model_loader/loader.py
View file @
1a8bfd92
...
@@ -34,6 +34,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -34,6 +34,7 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator
,
safetensors_weights_iterator
)
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_tpu
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -227,12 +228,26 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -227,12 +228,26 @@ class DefaultModelLoader(BaseModelLoader):
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
# Currently np_cache only support *.bin checkpoints
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
assert
use_safetensors
is
False
return
np_cache_weights_iterator
(
model_name_or_path
,
weights_iterator
=
np_cache_weights_iterator
(
self
.
load_config
.
download_dir
,
model_name_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_folder
,
hf_weights_files
)
hf_weights_files
)
if
use_safetensors
:
elif
use_safetensors
:
return
safetensors_weights_iterator
(
hf_weights_files
)
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
return
pt_weights_iterator
(
hf_weights_files
)
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
if
is_tpu
():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import
torch_xla.core.xla_model
as
xm
def
_xla_weights_iterator
(
iterator
:
Generator
):
for
weights
in
iterator
:
yield
weights
xm
.
mark_step
()
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
return
weights_iterator
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
...
...
vllm/utils.py
View file @
1a8bfd92
...
@@ -146,6 +146,15 @@ def is_neuron() -> bool:
...
@@ -146,6 +146,15 @@ def is_neuron() -> bool:
return
transformers_neuronx
is
not
None
return
transformers_neuronx
is
not
None
@
lru_cache
(
maxsize
=
None
)
def
is_tpu
()
->
bool
:
try
:
import
libtpu
except
ImportError
:
libtpu
=
None
return
libtpu
is
not
None
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
"""Returns the maximum shared memory per thread block in bytes."""
"""Returns the maximum shared memory per thread block in bytes."""
...
@@ -546,6 +555,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
...
@@ -546,6 +555,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
return
tensor
return
tensor
def
get_dtype_size
(
dtype
:
torch
.
dtype
)
->
int
:
"""Get the size of the data type in bytes."""
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
def
merge_dicts
(
dict1
:
Dict
[
Any
,
List
[
Any
]],
def
merge_dicts
(
dict1
:
Dict
[
Any
,
List
[
Any
]],
dict2
:
Dict
[
Any
,
List
[
Any
]])
->
Dict
[
Any
,
List
[
Any
]]:
dict2
:
Dict
[
Any
,
List
[
Any
]])
->
Dict
[
Any
,
List
[
Any
]]:
"""Merge 2 dicts that have key -> List of items.
"""Merge 2 dicts that have key -> List of items.
...
...
vllm/worker/cache_engine.py
View file @
1a8bfd92
...
@@ -6,7 +6,8 @@ import torch
...
@@ -6,7 +6,8 @@ import torch
from
vllm.attention
import
get_attn_backend
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
CacheConfig
,
ModelConfig
,
ParallelConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
is_pin_memory_available
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
,
is_pin_memory_available
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -108,9 +109,5 @@ class CacheEngine:
...
@@ -108,9 +109,5 @@ class CacheEngine:
dtype
=
model_config
.
dtype
dtype
=
model_config
.
dtype
else
:
else
:
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
dtype_size
=
_
get_dtype_size
(
dtype
)
dtype_size
=
get_dtype_size
(
dtype
)
return
dtype_size
*
total
return
dtype_size
*
total
def
_get_dtype_size
(
dtype
:
torch
.
dtype
)
->
int
:
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment