Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5753ff5
Commit
a5753ff5
authored
Jun 19, 2024
by
zhuwenwen
Browse files
v0.5.0.post1
parents
21c06ecb
0f0d8bc0
Changes
108
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
798 additions
and
9 deletions
+798
-9
vllm/utils.py
vllm/utils.py
+49
-0
vllm/version.py
vllm/version.py
+1
-0
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+3
-6
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+2
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+12
-1
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+525
-0
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+198
-0
vllm/worker/worker.py
vllm/worker/worker.py
+8
-0
No files found.
vllm/utils.py
View file @
a5753ff5
...
@@ -146,6 +146,15 @@ def is_neuron() -> bool:
...
@@ -146,6 +146,15 @@ def is_neuron() -> bool:
return
transformers_neuronx
is
not
None
return
transformers_neuronx
is
not
None
@
lru_cache
(
maxsize
=
None
)
def
is_tpu
()
->
bool
:
try
:
import
libtpu
except
ImportError
:
libtpu
=
None
return
libtpu
is
not
None
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
def
get_max_shared_memory_bytes
(
gpu
:
int
=
0
)
->
int
:
"""Returns the maximum shared memory per thread block in bytes."""
"""Returns the maximum shared memory per thread block in bytes."""
...
@@ -546,6 +555,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
...
@@ -546,6 +555,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
return
tensor
return
tensor
def
get_dtype_size
(
dtype
:
torch
.
dtype
)
->
int
:
"""Get the size of the data type in bytes."""
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
def
merge_dicts
(
dict1
:
Dict
[
Any
,
List
[
Any
]],
def
merge_dicts
(
dict1
:
Dict
[
Any
,
List
[
Any
]],
dict2
:
Dict
[
Any
,
List
[
Any
]])
->
Dict
[
Any
,
List
[
Any
]]:
dict2
:
Dict
[
Any
,
List
[
Any
]])
->
Dict
[
Any
,
List
[
Any
]]:
"""Merge 2 dicts that have key -> List of items.
"""Merge 2 dicts that have key -> List of items.
...
@@ -679,3 +693,38 @@ def deprecate_kwargs(
...
@@ -679,3 +693,38 @@ def deprecate_kwargs(
return
inner
# type: ignore
return
inner
# type: ignore
return
wrapper
return
wrapper
@
lru_cache
(
maxsize
=
8
)
def
_cuda_device_count_stateless
(
cuda_visible_devices
:
Optional
[
str
]
=
None
)
->
int
:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import
torch.cuda
import
torch.version
if
not
torch
.
cuda
.
_is_compiled
():
return
0
# bypass _device_count_nvml() if rocm (not supported)
nvml_count
=
-
1
if
torch
.
version
.
hip
else
torch
.
cuda
.
_device_count_nvml
()
r
=
torch
.
_C
.
_cuda_getDeviceCount
()
if
nvml_count
<
0
else
nvml_count
return
r
def
cuda_device_count_stateless
()
->
int
:
"""Get number of CUDA devices, caching based on the value of
CUDA_VISIBLE_DEVICES at the time of call.
This should be used instead of torch.cuda.device_count()
unless CUDA_VISIBLE_DEVICES has already been set to the desired
value."""
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released.
return
_cuda_device_count_stateless
(
envs
.
CUDA_VISIBLE_DEVICES
)
vllm/version.py
0 → 100644
View file @
a5753ff5
__version__
=
"0.5.0.post1"
vllm/worker/cache_engine.py
View file @
a5753ff5
...
@@ -6,7 +6,8 @@ import torch
...
@@ -6,7 +6,8 @@ import torch
from
vllm.attention
import
get_attn_backend
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
CacheConfig
,
ModelConfig
,
ParallelConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
is_pin_memory_available
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
,
is_pin_memory_available
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -108,9 +109,5 @@ class CacheEngine:
...
@@ -108,9 +109,5 @@ class CacheEngine:
dtype
=
model_config
.
dtype
dtype
=
model_config
.
dtype
else
:
else
:
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
dtype_size
=
_
get_dtype_size
(
dtype
)
dtype_size
=
get_dtype_size
(
dtype
)
return
dtype_size
*
total
return
dtype_size
*
total
def
_get_dtype_size
(
dtype
:
torch
.
dtype
)
->
int
:
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
vllm/worker/cpu_model_runner.py
View file @
a5753ff5
...
@@ -343,8 +343,8 @@ class CPUModelRunner:
...
@@ -343,8 +343,8 @@ class CPUModelRunner:
"kv_caches"
:
kv_caches
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
"attn_metadata"
:
attn_metadata
,
}
}
if
self
.
vision_language_config
:
if
self
.
vision_language_config
and
multi_modal_input
is
not
None
:
execute_model_kwargs
.
update
(
{
"image_input"
:
multi_modal_input
}
)
execute_model_kwargs
.
update
(
multi_modal_input
)
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
...
...
vllm/worker/model_runner.py
View file @
a5753ff5
...
@@ -13,13 +13,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -13,13 +13,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.
communication_op
import
graph_capture
from
vllm.distributed.
parallel_state
import
graph_capture
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
...
@@ -222,6 +223,16 @@ class ModelRunner:
...
@@ -222,6 +223,16 @@ class ModelRunner:
max_size
=
max_size
,
max_size
=
max_size
,
)
)
def
save_tensorized_model
(
self
,
tensorizer_config
:
TensorizerConfig
,
)
->
None
:
from
vllm.model_executor.model_loader.loader
import
TensorizerLoader
TensorizerLoader
.
save_model
(
self
.
model
,
tensorizer_config
=
tensorizer_config
,
)
def
get_max_block_per_batch
(
self
)
->
int
:
def
get_max_block_per_batch
(
self
)
->
int
:
block_size
=
self
.
block_size
block_size
=
self
.
block_size
return
(
self
.
max_seq_len_to_capture
+
block_size
-
1
)
//
block_size
return
(
self
.
max_seq_len_to_capture
+
block_size
-
1
)
//
block_size
...
...
vllm/worker/tpu_model_runner.py
0 → 100644
View file @
a5753ff5
import
time
from
typing
import
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch_xla.core.xla_model
as
xm
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
make_tensor_with_pad
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
0
# FIXME(woosuk)
class
TPUModelRunner
:
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
load_config
=
load_config
self
.
vision_language_config
=
vision_language_config
self
.
block_size
=
self
.
cache_config
.
block_size
self
.
max_num_blocks_per_seq
=
(
self
.
model_config
.
max_model_len
//
self
.
block_size
)
self
.
block_tables
=
np
.
zeros
(
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
max_num_blocks_per_seq
),
dtype
=
np
.
int32
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
block_size
,
False
,
)
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
model
=
get_model
(
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
device_config
=
self
.
device_config
,
parallel_config
=
self
.
parallel_config
,
cache_config
=
self
.
cache_config
,
scheduler_config
=
self
.
scheduler_config
,
vision_language_config
=
self
.
vision_language_config
,
lora_config
=
None
,
)
xm
.
wait_device_ops
()
model
=
ModelWrapper
(
model
)
self
.
model
=
torch
.
compile
(
model
,
backend
=
"openxla"
,
fullgraph
=
True
)
def
_dummy_run
(
self
,
batch_size
:
int
,
seq_len
:
int
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
is_prompt
:
bool
,
)
->
None
:
if
is_prompt
:
seq_len
=
(
seq_len
+
15
)
//
16
*
16
token_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
position_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slot_mapping
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
batch_size
,
num_prefill_tokens
=
batch_size
*
seq_len
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
block_tables
=
None
,
context_lens
=
None
,
)
input_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
else
:
assert
seq_len
==
1
token_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
position_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slot_mapping
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
block_tables
=
torch
.
zeros
(
(
batch_size
,
self
.
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
context_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
*
seq_len
,
slot_mapping
=
slot_mapping
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
)
t
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
# Dummy run.
self
.
model
(
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
input_lens
,
t
,
p
)
def
warmup_model
(
self
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
None
:
# Prefill
logger
.
info
(
"Compiling the model with different input shapes..."
)
start
=
time
.
time
()
for
batch_size
in
[
1
]:
seq_len
=
16
while
True
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
True
)
xm
.
wait_device_ops
()
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
if
seq_len
>=
self
.
model_config
.
max_model_len
:
break
num_tokens
=
batch_size
*
seq_len
if
num_tokens
>=
self
.
scheduler_config
.
max_num_batched_tokens
:
break
seq_len
=
seq_len
*
2
end
=
time
.
time
()
logger
.
info
(
"Compilation for prefill done in %.2f s."
,
end
-
start
)
# Decode
start
=
time
.
time
()
seq_len
=
1
batch_size
=
1
while
True
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
False
)
xm
.
wait_device_ops
()
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
if
batch_size
>=
self
.
scheduler_config
.
max_num_seqs
:
break
batch_size
=
batch_size
+
16
if
batch_size
>=
16
else
batch_size
*
2
end
=
time
.
time
()
logger
.
info
(
"Compilation for decode done in %.2f s."
,
end
-
start
)
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
):
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
prompt_lens
:
List
[
int
]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
# Could include output tokens when a request is preempted.
prompt_tokens
=
seq_data
.
get_token_ids
()
prompt_len
=
len
(
prompt_tokens
)
prompt_lens
.
append
(
prompt_len
)
input_tokens
.
append
(
prompt_tokens
)
input_positions
.
append
(
list
(
range
(
prompt_len
)))
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
slot_mapping
.
append
([])
for
i
in
range
(
prompt_len
):
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
[
-
1
].
append
(
slot
)
assert
len
(
prompt_lens
)
>
0
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
sum
(
prompt_lens
)
# Add paddings to make the shape [batch_size, max_prompt_len] where
# max_prompt_len is smallest power of 2 that is greater than or equal
# to the maximum prompt length.
# We need the 2D input shape because the Pallas FlashAttention kernel
# does not support packed 1D inputs.
# We pad the seq_len to powers of 2 to reduce the compilation overhead.
max_prompt_len
=
_get_padded_prefill_len
(
max
(
prompt_lens
))
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slot_mapping
=
make_tensor_with_pad
(
slot_mapping
,
max_prompt_len
,
pad
=
_PAD_SLOT_ID
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
# NOTE: This is not used.
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
block_tables
=
None
,
context_lens
=
None
,
)
return
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
def
_prepare_decode
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
):
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
context_lens
:
List
[
int
]
=
[]
num_seq_groups
=
len
(
seq_group_metadata_list
)
batch_size
=
_get_padded_batch_size
(
num_seq_groups
)
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
assert
not
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
input_positions
.
append
([
position
])
context_lens
.
append
(
seq_len
)
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
self
.
block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_number
=
block_table
[
position
//
self
.
block_size
]
block_offset
=
position
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
([
slot
])
num_paddings
=
batch_size
-
num_seq_groups
input_tokens
=
input_tokens
+
[[
0
]]
*
num_paddings
input_positions
=
input_positions
+
[[
0
]]
*
num_paddings
slot_mapping
=
slot_mapping
+
[[
_PAD_SLOT_ID
]]
*
num_paddings
context_lens
=
context_lens
+
[
0
]
*
num_paddings
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_lens
=
torch
.
tensor
([
1
]
*
batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
slot_mapping
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
)
return
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
def
_prepare_sample
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
padded_batch_size
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
t
=
[]
p
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
sampling_params
is
not
None
sampling_params
=
seq_group_metadata
.
sampling_params
t
.
append
(
sampling_params
.
temperature
if
sampling_params
.
temperature
>=
1e-5
else
1e-5
)
p
.
append
(
sampling_params
.
top_p
)
num_paddings
=
padded_batch_size
-
len
(
seq_group_metadata_list
)
t
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
return
t
,
p
def
prepare_inputs
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
):
assert
seq_group_metadata_list
is
not
None
assert
len
(
seq_group_metadata_list
)
>
0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
if
seq_group_metadata_list
[
0
].
is_prompt
:
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
padded_batch_size
=
inputs
[
0
].
shape
[
0
]
sample_inputs
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
return
inputs
+
sample_inputs
def
_execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
List
[
CompletionSequenceGroupOutput
]:
inputs
=
self
.
prepare_inputs
(
seq_group_metadata_list
)
next_token_ids
=
self
.
model
(
inputs
[
0
],
inputs
[
1
],
kv_caches
,
*
inputs
[
2
:])
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
i
=
0
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_outputs
=
[]
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
next_token_id
=
next_token_ids
[
i
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
Logprob
(
0.0
)}))
i
+=
1
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
sampler_outputs
def
execute_model
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
SamplerOutput
:
assert
seq_group_metadata_list
is
not
None
if
seq_group_metadata_list
[
0
].
is_prompt
:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
# aware of this limitation, we need to handle batch size > 1
# internally by calling the model multiple times and concatenating
# the outputs.
# FIXME(woosuk): This is a temporary hack to not change the existing
# scheduler. We need to fix this in the future.
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
sampler_outputs
+=
self
.
_execute_model
([
seq_group_metadata
],
kv_caches
)
else
:
sampler_outputs
=
self
.
_execute_model
(
seq_group_metadata_list
,
kv_caches
)
return
SamplerOutput
(
sampler_outputs
)
class
ModelWrapper
(
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
self
.
model
=
model
.
eval
()
def
forward
(
self
,
token_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
attn_metadata
:
AttentionMetadata
,
input_lens
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
"""
batch_size
,
seq_len
=
token_ids
.
shape
# Calculate the positions to sample from.
base_indicies
=
torch
.
arange
(
batch_size
,
dtype
=
torch
.
int32
,
device
=
input_lens
.
device
)
*
seq_len
logits_indices
=
base_indicies
+
input_lens
-
1
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
[],
selected_token_indices
=
logits_indices
,
categorized_sample_indices
=
{},
num_prompts
=
attn_metadata
.
num_prefills
,
)
# Skip this in memory profiling at initialization.
if
kv_caches
[
0
][
0
]
is
not
None
:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
num_kv_heads
,
num_blocks
,
block_size
,
_
=
kv_caches
[
0
][
0
].
shape
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
slot_mapping
.
flatten
()
head_indicies
=
torch
.
arange
(
0
,
num_kv_heads
,
device
=
slot_mapping
.
device
,
dtype
=
slot_mapping
.
dtype
)
head_indicies
*=
block_size
*
num_blocks
slot_mapping
=
slot_mapping
.
repeat_interleave
(
num_kv_heads
).
view
(
-
1
,
num_kv_heads
)
slot_mapping
=
slot_mapping
+
head_indicies
.
view
(
1
,
-
1
)
slot_mapping
=
slot_mapping
.
flatten
()
attn_metadata
.
slot_mapping
=
slot_mapping
hidden_states
=
self
.
model
(
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
logits
/
t
.
unsqueeze
(
dim
=
1
)
# FIXME(woosuk): Disabled top-p sampling since it's too slow.
# logits = _apply_top_p(logits, p.unsqueeze(dim=1))
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
# FIXME(woosuk): best_of > 1 is not supported.
next_token_ids
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
dim
=
1
)
return
next_token_ids
def
_get_padded_prefill_len
(
x
:
int
)
->
int
:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
if
x
<=
16
:
return
16
return
1
<<
(
x
-
1
).
bit_length
()
def
_get_padded_batch_size
(
batch_size
:
int
)
->
int
:
if
batch_size
<=
2
:
return
batch_size
elif
batch_size
<=
4
:
return
4
elif
batch_size
<=
8
:
return
8
else
:
return
((
batch_size
+
15
)
//
16
)
*
16
def
_apply_top_p
(
logits
:
torch
.
Tensor
,
p
:
torch
.
Tensor
)
->
torch
.
Tensor
:
logits_sorted
=
torch
.
sort
(
logits
,
dim
=-
1
,
descending
=
True
).
values
sorted_cum_probs
=
torch
.
cumsum
(
logits_sorted
.
softmax
(
dim
=-
1
),
dim
=-
1
)
cutoff_index
=
torch
.
sum
(
sorted_cum_probs
<
p
,
dim
=-
1
,
keepdim
=
True
)
cutoff_logit
=
torch
.
gather
(
logits_sorted
,
-
1
,
cutoff_index
)
logits
=
logits
.
masked_fill_
(
logits
<
cutoff_logit
,
-
float
(
"inf"
))
return
logits
vllm/worker/tpu_worker.py
0 → 100644
View file @
a5753ff5
import
os
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
logger
=
init_logger
(
__name__
)
class
TPUWorker
(
LoraNotSupportedWorkerBase
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
vision_language_config
:
Optional
[
VisionLanguageConfig
],
local_rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
)
->
None
:
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
load_config
=
load_config
self
.
vision_language_config
=
vision_language_config
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
assert
self
.
device_config
.
device_type
==
"tpu"
if
self
.
cache_config
.
cache_dtype
==
"auto"
:
self
.
cache_dtype
=
self
.
model_config
.
dtype
else
:
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_config
.
cache_dtype
]
self
.
model_runner
=
TPUModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
,
vision_language_config
)
def
init_device
(
self
)
->
None
:
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
self
.
device
=
xm
.
xla_device
()
self
.
device_config
.
device
=
self
.
device
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
self
.
model_config
.
dtype
)
# NOTE(woosuk): This is just a hack to initialize the TP group.
# This cannot perform the actual communication ops.
init_distributed_environment
(
world_size
=
self
.
parallel_config
.
world_size
,
rank
=
self
.
rank
,
local_rank
=
self
.
local_rank
,
distributed_init_method
=
self
.
distributed_init_method
,
backend
=
"gloo"
,
)
ensure_model_parallel_initialized
(
self
.
parallel_config
.
tensor_parallel_size
,
self
.
parallel_config
.
pipeline_parallel_size
)
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
xm
.
set_rng_state
(
self
.
model_config
.
seed
,
self
.
device
)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
torch
.
_dynamo
.
config
.
cache_size_limit
=
128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): This does not completely eliminate the recompilation
# overhead because dynamo does not cache the compiled results.
xr
.
initialize_cache
(
os
.
path
.
expanduser
(
envs
.
VLLM_XLA_CACHE_PATH
),
readonly
=
False
)
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
head_size
=
self
.
model_config
.
get_head_size
()
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
kv_caches
=
[(
None
,
None
)
for
_
in
range
(
num_layers
)]
self
.
model_runner
.
_dummy_run
(
batch_size
=
1
,
seq_len
=
self
.
scheduler_config
.
max_num_batched_tokens
,
kv_caches
=
kv_caches
,
is_prompt
=
True
,
)
# Synchronize before measuring the memory usage.
xm
.
wait_device_ops
()
m
=
xm
.
get_memory_info
(
self
.
device
)
program_size
=
1024
*
1024
*
1024
# 1GB
free_bytes
=
max
(
m
[
"bytes_limit"
]
-
m
[
"bytes_used"
]
-
program_size
,
0
)
kv_cache_bytes
=
int
(
free_bytes
*
self
.
cache_config
.
gpu_memory_utilization
)
kv_cache_dtype_btyes
=
get_dtype_size
(
self
.
cache_dtype
)
block_size
=
self
.
cache_config
.
block_size
num_tpu_blocks
=
(
kv_cache_bytes
//
(
kv_cache_dtype_btyes
*
block_size
*
num_layers
*
2
*
head_size
*
num_kv_heads
))
num_tpu_blocks
=
(
num_tpu_blocks
//
8
)
*
8
# Round down to 8.
return
num_tpu_blocks
,
0
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
,
)
->
None
:
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
self
.
block_size
=
self
.
cache_config
.
block_size
dtype
=
self
.
cache_dtype
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
head_size
=
self
.
model_config
.
get_head_size
()
self
.
tpu_cache
=
[]
tpu_cache_shape
=
self
.
model_runner
.
attn_backend
.
get_kv_cache_shape
(
num_gpu_blocks
,
self
.
block_size
,
num_kv_heads
,
head_size
)
for
_
in
range
(
num_layers
):
key_cache
=
torch
.
zeros
(
tpu_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
value_cache
=
torch
.
zeros_like
(
key_cache
)
self
.
tpu_cache
.
append
((
key_cache
,
value_cache
))
self
.
_warmup_model
()
def
_warmup_model
(
self
)
->
None
:
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
# for CUDA graphs. We should refactor this part.
if
not
self
.
model_config
.
enforce_eager
:
# Warm up the model with all possible input shapes so that
# compilation never happens during the actual execution.
# This may take ~30 mins for the first run and ~20 mins for the
# subsequent runs.
# If `enforce_eager` is True, the ahead-of-time compilation is
# skipped and the compilation happens during the actual execution,
# which is bad for performance but useful for development.
self
.
model_runner
.
warmup_model
(
self
.
tpu_cache
)
def
get_cache_block_size_bytes
(
self
)
->
int
:
head_size
=
self
.
model_config
.
get_head_size
()
num_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
key_cache_block
=
self
.
cache_config
.
block_size
*
num_heads
*
head_size
value_cache_block
=
key_cache_block
total
=
num_layers
*
(
key_cache_block
+
value_cache_block
)
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
return
dtype_size
*
total
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
if
execute_model_req
is
None
:
return
[]
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
num_seq_groups
=
len
(
seq_group_metadata_list
)
if
num_seq_groups
==
0
:
return
[]
# Currently, TPUWorker does not support swapping.
# TODO(woosuk): Support block copying.
assert
len
(
execute_model_req
.
blocks_to_swap_in
)
==
0
,
(
"Swapping is not supported for the TPU backend."
)
assert
len
(
execute_model_req
.
blocks_to_swap_out
)
==
0
,
(
"Swapping is not supported for the TPU backend."
)
assert
len
(
execute_model_req
.
blocks_to_copy
)
==
0
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
tpu_cache
)
return
[
output
]
vllm/worker/worker.py
View file @
a5753ff5
...
@@ -15,6 +15,7 @@ from vllm.distributed import (broadcast_tensor_dict,
...
@@ -15,6 +15,7 @@ from vllm.distributed import (broadcast_tensor_dict,
set_custom_all_reduce
)
set_custom_all_reduce
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
...
@@ -132,6 +133,13 @@ class Worker(WorkerBase):
...
@@ -132,6 +133,13 @@ class Worker(WorkerBase):
max_size
=
max_size
,
max_size
=
max_size
,
)
)
def
save_tensorized_model
(
self
,
tensorizer_config
:
TensorizerConfig
,
)
->
None
:
self
.
model_runner
.
save_tensorized_model
(
tensorizer_config
=
tensorizer_config
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Profiles the peak memory usage of the model to determine how many
"""Profiles the peak memory usage of the model to determine how many
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment