Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
52f07e3d
Unverified
Commit
52f07e3d
authored
Jul 26, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 26, 2024
Browse files
[Hardware][TPU] Implement tensor parallelism with Ray (#5871)
parent
14dbd5a7
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
365 additions
and
21 deletions
+365
-21
requirements-tpu.txt
requirements-tpu.txt
+1
-0
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+2
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+8
-2
vllm/executor/ray_tpu_executor.py
vllm/executor/ray_tpu_executor.py
+313
-0
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+31
-11
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+10
-6
No files found.
requirements-tpu.txt
View file @
52f07e3d
...
@@ -4,4 +4,5 @@
...
@@ -4,4 +4,5 @@
# Dependencies for TPU
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
# You can install the dependencies in Dockerfile.tpu.
ray
triton # To avoid import errors
triton # To avoid import errors
vllm/attention/backends/pallas.py
View file @
52f07e3d
...
@@ -55,8 +55,8 @@ class PallasMetadata(AttentionMetadata):
...
@@ -55,8 +55,8 @@ class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills
# Currently, input sequences can only contain all prefills
# or all decoding.
# or all decoding.
block_tables
:
Optional
[
torch
.
Tensor
]
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
context_lens
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
]
=
None
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
...
...
vllm/engine/llm_engine.py
View file @
52f07e3d
...
@@ -394,6 +394,12 @@ class LLMEngine:
...
@@ -394,6 +394,12 @@ class LLMEngine:
from
vllm.executor.neuron_executor
import
NeuronExecutor
from
vllm.executor.neuron_executor
import
NeuronExecutor
executor_class
=
NeuronExecutor
executor_class
=
NeuronExecutor
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
if
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutor
executor_class
=
RayTPUExecutor
else
:
assert
distributed_executor_backend
is
None
from
vllm.executor.tpu_executor
import
TPUExecutor
from
vllm.executor.tpu_executor
import
TPUExecutor
executor_class
=
TPUExecutor
executor_class
=
TPUExecutor
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
...
...
vllm/executor/ray_tpu_executor.py
0 → 100644
View file @
52f07e3d
import
asyncio
import
os
from
collections
import
defaultdict
from
itertools
import
islice
,
repeat
from
typing
import
(
TYPE_CHECKING
,
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
)
import
vllm.envs
as
envs
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.executor.tpu_executor
import
TPUExecutor
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
logger
=
init_logger
(
__name__
)
class
RayTPUExecutor
(
TPUExecutor
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
self
.
parallel_worker_tasks
:
Optional
[
Union
[
Any
,
Awaitable
[
Any
]]]
=
None
# Updated by implementations that require additional args to be passed
# to the _run_workers execute_model call
self
.
extra_execute_model_run_workers_kwargs
:
Dict
[
str
,
Any
]
=
{}
super
().
__init__
(
*
args
,
**
kwargs
)
def
_init_executor
(
self
)
->
None
:
assert
self
.
parallel_config
.
distributed_executor_backend
==
"ray"
placement_group
=
self
.
parallel_config
.
placement_group
# Disable Ray usage stats collection.
ray_usage
=
os
.
environ
.
get
(
"RAY_USAGE_STATS_ENABLED"
,
"0"
)
if
ray_usage
!=
"1"
:
os
.
environ
[
"RAY_USAGE_STATS_ENABLED"
]
=
"0"
# Create the parallel TPU workers.
self
.
_init_workers_ray
(
placement_group
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self
.
driver_dummy_worker
:
Optional
[
RayWorkerWrapper
]
=
None
# The remaining workers are the actual ray actors.
self
.
workers
:
List
[
RayWorkerWrapper
]
=
[]
# Create the workers.
driver_ip
=
get_ip
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"TPU"
,
0
):
continue
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
placement_group
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
bundle_id
,
)
assert
self
.
speculative_config
is
None
worker_module_name
=
"vllm.worker.tpu_worker"
worker_class_name
=
"TPUWorker"
worker
=
ray
.
remote
(
num_cpus
=
0
,
resources
=
{
"TPU"
:
1
},
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
else
:
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
if
self
.
driver_dummy_worker
is
None
:
raise
ValueError
(
"Ray does not allocate any TPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"TPU node."
)
# Get the set of TPU IDs used on each node.
worker_node_and_gpu_ids
=
self
.
_run_workers
(
"get_node_and_gpu_ids"
,
use_dummy_driver
=
True
)
node_workers
=
defaultdict
(
list
)
for
i
,
(
node_id
,
_
)
in
enumerate
(
worker_node_and_gpu_ids
):
node_workers
[
node_id
].
append
(
i
)
VLLM_INSTANCE_ID
=
get_vllm_instance_id
()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables
=
[({
"VLLM_INSTANCE_ID"
:
VLLM_INSTANCE_ID
,
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
},
)
for
_
in
worker_node_and_gpu_ids
]
self
.
_run_workers
(
"update_environment_variables"
,
all_args
=
all_args_to_update_environment_variables
)
if
len
(
node_workers
)
==
1
:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip
=
"127.0.0.1"
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs
=
[
self
.
_get_worker_kwargs
(
local_rank
=
node_workers
[
node_id
].
index
(
rank
),
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
)
for
rank
,
(
node_id
,
_
)
in
enumerate
(
worker_node_and_gpu_ids
)
]
self
.
_run_workers
(
"init_worker"
,
all_kwargs
=
init_worker_all_kwargs
)
self
.
_run_workers
(
"init_device"
)
self
.
_run_workers
(
"load_model"
,
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
)
def
_driver_execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return
self
.
driver_worker
.
execute_method
(
"execute_model"
,
execute_model_req
)
def
_run_workers
(
self
,
method
:
str
,
*
args
,
async_run_remote_workers_only
:
bool
=
False
,
all_args
:
Optional
[
List
[
Tuple
[
Any
,
...]]]
=
None
,
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
use_dummy_driver
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers. Can be used in the following
ways:
- async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than blocking
on the results.
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if
max_concurrent_workers
:
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
count
=
len
(
self
.
workers
)
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
else
islice
(
all_args
,
1
,
None
)
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
else
islice
(
all_kwargs
,
1
,
None
)
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
]
if
async_run_remote_workers_only
:
# Just return futures
return
ray_worker_outputs
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
method
,
*
driver_args
,
**
driver_kwargs
)
else
:
assert
self
.
driver_dummy_worker
is
not
None
driver_worker_output
=
ray
.
get
(
self
.
driver_dummy_worker
.
execute_method
.
remote
(
method
,
*
driver_args
,
**
driver_kwargs
))
# Get the results of the ray workers.
if
self
.
workers
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray
.
get
(
parallel_worker_tasks
)
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
num_blocks
=
self
.
_run_workers
(
"determine_num_available_blocks"
,
)
num_tpu_blocks
=
min
(
b
[
0
]
for
b
in
num_blocks
)
num_cpu_blocks
=
min
(
b
[
1
]
for
b
in
num_blocks
)
return
num_tpu_blocks
,
num_cpu_blocks
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
logger
.
info
(
"# TPU blocks: %d, # CPU blocks: %d"
,
num_gpu_blocks
,
num_cpu_blocks
)
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
self
.
_run_workers
(
"initialize_cache"
,
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
)
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
if
self
.
parallel_worker_tasks
is
None
:
self
.
parallel_worker_tasks
=
self
.
_run_workers
(
"start_worker_execution_loop"
,
async_run_remote_workers_only
=
True
,
**
self
.
extra_execute_model_run_workers_kwargs
)
# Only the driver worker returns the sampling results.
return
self
.
_driver_execute_model
(
execute_model_req
)
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
if
self
.
parallel_worker_tasks
is
None
:
return
self
.
_driver_execute_model
()
parallel_worker_tasks
=
self
.
parallel_worker_tasks
self
.
parallel_worker_tasks
=
None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self
.
_wait_for_tasks_completion
(
parallel_worker_tasks
)
class
RayTPUExecutorAsync
(
RayTPUExecutor
,
ExecutorAsyncBase
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
driver_exec_method
=
make_async
(
self
.
driver_worker
.
execute_method
)
async
def
execute_model_async
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
if
self
.
parallel_worker_tasks
is
None
:
# Start model execution loop running in the parallel workers
self
.
parallel_worker_tasks
=
asyncio
.
create_task
(
self
.
_start_worker_execution_loop
())
# Only the driver worker returns the sampling results.
return
await
self
.
_driver_execute_model_async
(
execute_model_req
)
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
if
self
.
parallel_worker_tasks
is
None
:
return
await
self
.
_driver_execute_model_async
()
parallel_worker_tasks
=
self
.
parallel_worker_tasks
self
.
parallel_worker_tasks
=
None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
await
parallel_worker_tasks
async
def
_driver_execute_model_async
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
return
await
self
.
driver_exec_method
(
"execute_model"
,
execute_model_req
)
async
def
_start_worker_execution_loop
(
self
):
coros
=
[
worker
.
execute_method
.
remote
(
"start_worker_execution_loop"
)
for
worker
in
self
.
workers
]
return
await
asyncio
.
gather
(
*
coros
)
vllm/worker/tpu_model_runner.py
View file @
52f07e3d
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
unittest.mock
import
patch
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -45,6 +46,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
...
@@ -45,6 +46,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
num_samples
:
int
num_samples
:
int
best_of
:
List
[
int
]
best_of
:
List
[
int
]
seq_groups
:
List
[
List
[
int
]]
seq_groups
:
List
[
List
[
int
]]
virtual_engine
:
int
=
0
def
as_broadcastable_tensor_dict
(
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
...
@@ -55,6 +57,9 @@ class ModelInputForTPU(ModelRunnerInputBase):
...
@@ -55,6 +57,9 @@ class ModelInputForTPU(ModelRunnerInputBase):
"t"
:
self
.
t
,
"t"
:
self
.
t
,
"p"
:
self
.
p
,
"p"
:
self
.
p
,
"num_samples"
:
self
.
num_samples
,
"num_samples"
:
self
.
num_samples
,
"best_of"
:
self
.
best_of
,
"seq_groups"
:
self
.
seq_groups
,
"virtual_engine"
:
self
.
virtual_engine
,
}
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
return
tensor_dict
...
@@ -113,6 +118,20 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -113,6 +118,20 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank
=
xm
.
get_ordinal
()
with
patch
(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank"
,
return_value
=
xm_tp_rank
):
model
=
get_model
(
model
=
get_model
(
model_config
=
self
.
model_config
,
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
load_config
=
self
.
load_config
,
...
@@ -463,10 +482,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -463,10 +482,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
tensor_dict
,
attn_backend
=
self
.
attn_backend
)
tensor_dict
,
attn_backend
=
self
.
attn_backend
)
return
model_input
return
model_input
@
torch
.
no_grad
()
def
execute_model
(
def
execute_model
(
self
,
self
,
model_input
:
ModelInputForTPU
,
model_input
:
ModelInputForTPU
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
Optional
[
List
[
Any
]],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
...
...
vllm/worker/tpu_worker.py
View file @
52f07e3d
...
@@ -70,13 +70,13 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -70,13 +70,13 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
self
.
device
=
xm
.
xla_device
()
self
.
device_config
.
device
=
self
.
device
torch
.
set_grad_enabled
(
False
)
torch
.
set_grad_enabled
(
False
)
torch
.
set_default_dtype
(
self
.
model_config
.
dtype
)
torch
.
set_default_dtype
(
self
.
model_config
.
dtype
)
# NOTE(woosuk): This is just a hack to initialize the TP group.
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# This cannot perform the actual communication ops.
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment
(
init_distributed_environment
(
world_size
=
self
.
parallel_config
.
world_size
,
world_size
=
self
.
parallel_config
.
world_size
,
rank
=
self
.
rank
,
rank
=
self
.
rank
,
...
@@ -88,6 +88,11 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -88,6 +88,11 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
parallel_config
.
tensor_parallel_size
,
self
.
parallel_config
.
tensor_parallel_size
,
self
.
parallel_config
.
pipeline_parallel_size
)
self
.
parallel_config
.
pipeline_parallel_size
)
# Device initialization should happen after initializing the distributed
# runtime.
self
.
device
=
xm
.
xla_device
()
self
.
device_config
.
device
=
self
.
device
# Set random seed.
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
xm
.
set_rng_state
(
self
.
model_config
.
seed
,
self
.
device
)
xm
.
set_rng_state
(
self
.
model_config
.
seed
,
self
.
device
)
...
@@ -200,8 +205,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -200,8 +205,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
@
property
@
property
def
do_metadata_broadcast
(
self
)
->
bool
:
def
do_metadata_broadcast
(
self
)
->
bool
:
# TODO(woosuk): Support TP.
return
self
.
parallel_config
.
tensor_parallel_size
>
1
return
False
@
property
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment