Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4172235a
Unverified
Commit
4172235a
authored
Sep 06, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 06, 2025
Browse files
[V0 deprecation] Deprecate V0 Neuron backend (#21159)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
848562bd
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
0 additions
and
1118 deletions
+0
-1118
vllm/platforms/__init__.py
vllm/platforms/__init__.py
+0
-25
vllm/platforms/interface.py
vllm/platforms/interface.py
+0
-4
vllm/platforms/neuron.py
vllm/platforms/neuron.py
+0
-151
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+0
-455
vllm/worker/neuron_worker.py
vllm/worker/neuron_worker.py
+0
-189
vllm/worker/neuronx_distributed_model_runner.py
vllm/worker/neuronx_distributed_model_runner.py
+0
-294
No files found.
vllm/platforms/__init__.py
View file @
4172235a
...
...
@@ -169,37 +169,12 @@ def cpu_platform_plugin() -> Optional[str]:
return
"vllm.platforms.cpu.CpuPlatform"
if
is_cpu
else
None
def neuron_platform_plugin() -> Optional[str]:
    """Probe for a Neuron software stack and name the platform class.

    The platform counts as available when either ``transformers_neuronx`` or
    ``neuronx_distributed_inference`` can be imported.

    Returns:
        The dotted path ``"vllm.platforms.neuron.NeuronPlatform"`` when at
        least one of the two packages is importable, otherwise ``None``.
    """
    logger.debug("Checking if Neuron platform is available.")

    # Probe the transformers-neuronx stack.
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        has_tnx = False
    else:
        has_tnx = True
        logger.debug("Confirmed Neuron platform is available because"
                     " transformers_neuronx is found.")

    # Probe the neuronx-distributed-inference stack.
    try:
        import neuronx_distributed_inference  # noqa: F401
    except ImportError:
        has_nxd = False
    else:
        has_nxd = True
        logger.debug("Confirmed Neuron platform is available because"
                     " neuronx_distributed_inference is found.")

    if has_tnx or has_nxd:
        return "vllm.platforms.neuron.NeuronPlatform"
    return None
# Registry of built-in platforms: platform name -> probe function. The probes
# visible here return the dotted path of the platform class, or None when the
# platform is not available.
builtin_platform_plugins = dict(
    tpu=tpu_platform_plugin,
    cuda=cuda_platform_plugin,
    rocm=rocm_platform_plugin,
    xpu=xpu_platform_plugin,
    cpu=cpu_platform_plugin,
    neuron=neuron_platform_plugin,
)
...
...
vllm/platforms/interface.py
View file @
4172235a
...
...
@@ -73,7 +73,6 @@ class PlatformEnum(enum.Enum):
TPU
=
enum
.
auto
()
XPU
=
enum
.
auto
()
CPU
=
enum
.
auto
()
NEURON
=
enum
.
auto
()
OOT
=
enum
.
auto
()
UNSPECIFIED
=
enum
.
auto
()
...
...
@@ -164,9 +163,6 @@ class Platform:
def is_cpu(self) -> bool:
    """Return True when this platform's ``_enum`` equals ``PlatformEnum.CPU``."""
    expected = PlatformEnum.CPU
    return self._enum == expected
def is_neuron(self) -> bool:
    """Return True when this platform's ``_enum`` equals ``PlatformEnum.NEURON``."""
    expected = PlatformEnum.NEURON
    return self._enum == expected
def is_out_of_tree(self) -> bool:
    """Return True when this platform's ``_enum`` equals ``PlatformEnum.OOT``."""
    expected = PlatformEnum.OOT
    return self._enum == expected
...
...
vllm/platforms/neuron.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
import
os
from
functools
import
lru_cache
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.utils
import
DEFAULT_MAX_NUM_BATCHED_TOKENS
from
.interface
import
Platform
,
PlatformEnum
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
else
:
VllmConfig
=
None
logger
=
init_logger
(
__name__
)
class NeuronFramework(enum.Enum):
    """Neuron model frameworks that the platform can select between.

    Each member's value is the string that is compared against the
    ``VLLM_NEURON_FRAMEWORK`` environment variable when choosing a framework.
    """

    # Selected with VLLM_NEURON_FRAMEWORK=transformers-neuronx
    TRANSFORMERS_NEURONX = "transformers-neuronx"

    # Selected with VLLM_NEURON_FRAMEWORK=neuronx-distributed-inference
    NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference"
class NeuronPlatform(Platform):
    """Platform implementation for AWS Neuron devices.

    Besides the generic ``Platform`` hooks, this class decides which Neuron
    model framework (transformers-neuronx or neuronx-distributed-inference)
    is used, based on which packages are importable and on the
    ``VLLM_NEURON_FRAMEWORK`` environment variable.
    """
    _enum = PlatformEnum.NEURON
    device_name: str = "neuron"
    device_type: str = "neuron"
    ray_device_key: str = "neuron_cores"
    supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
    dist_backend: str = "gloo"
    device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name; it is ``"neuron"`` for every device id."""
        return "neuron"

    @classmethod
    def is_async_output_supported(cls,
                                  enforce_eager: Optional[bool]) -> bool:
        """Return False: async output is reported as unsupported here."""
        return False

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for the Neuron backend.

        * Resolves ``worker_cls == "auto"`` to the Neuron worker class.
        * Forces the ``"uni"`` distributed executor backend when
          ``world_size > 1``.
        * Sets the cache block size to the model's ``max_model_len``.
        * When the model uses MLA, disables chunked prefill and raises
          ``max_num_batched_tokens`` to at least
          ``DEFAULT_MAX_NUM_BATCHED_TOKENS``.
        """
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = \
                "vllm.worker.neuron_worker.NeuronWorker"

        if parallel_config.world_size > 1:
            parallel_config.distributed_executor_backend = "uni"

        if vllm_config.cache_config and vllm_config.model_config:
            # neuron needs block_size = max_model_len
            vllm_config.cache_config.block_size = \
                vllm_config.model_config.max_model_len  # type: ignore

        if vllm_config.model_config and vllm_config.model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log a warning and return False (pinned memory is unsupported)."""
        logger.warning("Pin memory is not supported on Neuron.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the device communicator class.

        With ``envs.VLLM_USE_V1`` set, the Neuron communicator is used;
        otherwise the base ``Platform`` default is returned.
        """
        if envs.VLLM_USE_V1:
            return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator"  # noqa
        else:
            return Platform.get_device_communicator_cls()

    @classmethod
    def use_all_gather(cls) -> bool:
        """Return True unconditionally."""
        return True

    @classmethod
    @lru_cache
    def is_neuronx_distributed_inference(cls) -> bool:
        """Return True when ``neuronx_distributed_inference`` is importable.

        The result is cached, so the import is attempted once per class.
        """
        try:
            import neuronx_distributed_inference  # noqa: F401
        except ImportError:
            return False
        return True

    @classmethod
    @lru_cache
    def is_transformers_neuronx(cls) -> bool:
        """Return True when ``transformers_neuronx`` is importable.

        The result is cached, so the import is attempted once per class.
        """
        try:
            import transformers_neuronx  # noqa: F401
        except ImportError:
            return False
        return True

    def get_neuron_framework_to_use(self):
        """Return the specified framework if corresponding installations are
        available.

        If no framework is specified, use neuronx-distributed-inference by
        default.
        If that's unavailable, check and switch to transformers-neuronx.

        Returns:
            A ``NeuronFramework`` member, or ``None`` when the requested
            framework is not installed (or is not a known framework name).

        Raises:
            AssertionError: if this platform is not the Neuron platform.
        """
        if not self.is_neuron():
            raise AssertionError(
                f"Neuron Framework unavailable for platform: {self}")

        tnx_installed = self.is_transformers_neuronx()
        nxd_installed = self.is_neuronx_distributed_inference()

        specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK")
        tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value
        nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value

        # Explicit request for transformers-neuronx. The enum member lives on
        # NeuronFramework; this class defines no TRANSFORMERS_NEURONX
        # attribute, so it must not be looked up on ``self``.
        if specified_framework == tnx_framework and tnx_installed:
            return NeuronFramework.TRANSFORMERS_NEURONX

        # Explicit request for neuronx-distributed-inference, or no request
        # at all (it is the preferred default when installed).
        if ((specified_framework == nxd_framework and nxd_installed)
                or (specified_framework is None and nxd_installed)):
            return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE

        # No request and the default is unavailable: fall back.
        if specified_framework is None and tnx_installed:
            return NeuronFramework.TRANSFORMERS_NEURONX

        return None

    def use_neuronx_distributed(self):
        """
        Return True if the framework determined in get_neuron_framework_to_use()
        is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This
        is used to select the Neuron model framework and framework-specific
        configuration to apply during model compilation.
        """
        nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
        return self.get_neuron_framework_to_use() == nxd_framework

    def use_transformers_neuronx(self):
        """
        Return True if the framework determined in get_neuron_framework_to_use()
        is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used
        to select the Neuron model framework and framework-specific
        configuration to apply during model compilation.
        """
        return self.get_neuron_framework_to_use(
        ) == NeuronFramework.TRANSFORMERS_NEURONX
vllm/worker/neuron_model_runner.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.multimodal
import
BatchedTensorInputs
,
MultiModalKwargs
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
ModelRunnerBase
,
ModelRunnerInputBase
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
logger
=
init_logger
(
__name__
)
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
    """
    Immutable bundle of the inputs consumed by the NeuronModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: SamplingMetadata = None
    multi_modal_kwargs: BatchedTensorInputs = None
    adapter_ids: Optional[str] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Export the broadcastable fields as a plain dict.

        ``adapter_ids`` is not part of the exported dict.
        """
        exported = (
            "input_tokens",
            "input_positions",
            "input_block_ids",
            "sampling_metadata",
            "multi_modal_kwargs",
        )
        return {field: getattr(self, field) for field in exported}

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForNeuron":
        """Rebuild an input bundle from a dict with the exported keys.

        ``attn_backend`` is accepted but not used. A missing key raises
        ``KeyError``.
        """
        expected = (
            "input_tokens",
            "input_positions",
            "input_block_ids",
            "sampling_metadata",
            "multi_modal_kwargs",
        )
        fields = {key: tensor_dict[key] for key in expected}
        return ModelInputForNeuron(**fields)
class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    """A model runner for AWS Neuron hardware.

    Builds padded input tensors from scheduler metadata, keeps the on-device
    sampling configuration in sync with per-request sampling parameters, and
    runs the loaded Neuron model for a single step.
    """

    # NEURON has an upper limit on the top_k
    _MAX_NEURON_SAMPLING_TOP_K = 256

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        """Initialise runner state from ``vllm_config``.

        On-device sampling is set up unless the environment variable
        ``NEURON_ON_DEVICE_SAMPLING_DISABLED`` is a non-zero integer.
        """
        ModelRunnerBase.__init__(self, vllm_config)

        if (self.model_config is not None
                and self.model_config.get_sliding_window()):
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (self.device_config if self.device_config
                              is not None else DeviceConfig())
        self.lora_config = vllm_config.lora_config
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
        # turn off on-device sampling.
        self._on_device_sampling_disabled = int(
            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))

        # NEURON needs to update sampling parameters when request IDs change
        # across batches. This variable stores the previous batch's request IDs
        # to determine if an update is needed.
        self._previous_batch_request_ids: List[str] = []

        if not self._on_device_sampling_disabled:
            self._init_neuron_sampling()

    def _init_neuron_sampling(self) -> None:
        """Create the default on-device sampling config.

        The ``GenerationConfig`` class is taken from transformers-neuronx
        when that framework is in use, otherwise from ``transformers``. The
        result is stored on ``model_config.neuron_sampling_params`` with one
        top_k/top_p/temperature slot per schedulable sequence.
        """
        if current_platform.use_transformers_neuronx():
            from transformers_neuronx.config import GenerationConfig
        else:
            from transformers import GenerationConfig
        logger.warning(
            "On-device sampling is turned on in Neuron by default, only "
            "top_k, top_p, and temperature are current supported sampling "
            "parameters. To turn off the on-device sampling, please set "
            "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
        max_num_seqs = self.scheduler_config.max_num_seqs
        self.model_config.neuron_sampling_params = GenerationConfig(
            max_length=self.scheduler_config.max_model_len,
            do_sample=True,
            per_batch_line=True,
            top_k=[self._MAX_NEURON_SAMPLING_TOP_K] * max_num_seqs,
            top_p=[1.0] * max_num_seqs,
            temperature=[1.0] * max_num_seqs,
            dynamic=True,
            global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)

    def load_model(self) -> None:
        """Load the Neuron model and store it on ``self.model``."""
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)

    def get_model(self) -> nn.Module:
        """Return the loaded model."""
        return self.model

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
               BatchedTensorInputs]:
        """Build model inputs for a batch of prompt sequence groups.

        Every group must be a prompt holding exactly one sequence whose block
        table has exactly one entry (enforced by assertions).

        Returns:
            ``(input_tokens, input_positions, input_block_ids, seq_lens,
            multi_modal_kwargs)``; tokens and positions are zero-padded to
            the longest prompt in the batch.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

            mm_kwargs = seq_group_metadata.multi_modal_data
            if mm_kwargs:
                mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs)
                multi_modal_kwargs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=max_seq_len,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=max_seq_len,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build model inputs for a batch of decode sequence groups.

        For every sequence, the input is its last token, positioned at
        ``len - 1``. Each sequence's block table must have exactly one entry
        (enforced by assertion).

        Returns:
            ``(input_tokens, input_positions, input_block_ids)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=1,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=1,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
        """Rebuild a ``ModelInputForNeuron`` from a broadcast dict."""
        return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        """Turn scheduler metadata into a ``ModelInputForNeuron``.

        When on-device sampling is enabled, each group's ``sampling_params``
        is rewritten in place with Neuron-compatible top_k/top_p/temperature,
        and (for transformers-neuronx) the on-device sampling config is
        refreshed whenever the list of request IDs differs from the previous
        batch.
        """
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = None

        if not self._on_device_sampling_disabled:
            for seq_group_metadata in seq_group_metadata_list:
                sampling_params = seq_group_metadata.sampling_params
                top_k, top_p, temperature = (
                    self._convert_to_neuron_sampling_params(sampling_params))
                sampling_params.top_k = top_k
                sampling_params.top_p = top_p
                sampling_params.temperature = temperature

        # we need multi_modal_data for later tokens as well
        # (this replaces whatever _prepare_prompt returned above)
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        for seq_group_metadata in seq_group_metadata_list:
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                multi_modal_kwargs_list.append(mm_data)
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        if current_platform.use_transformers_neuronx(
        ) and not self._on_device_sampling_disabled:
            # Once the request IDs are changed in current iteration, we will
            # update the on-device sampling parameters.
            current_batch_request_ids = [
                seq_group_meta_data.request_id
                for seq_group_meta_data in seq_group_metadata_list
            ]
            if current_batch_request_ids != self._previous_batch_request_ids:
                self._update_neuron_sampling_params(seq_group_metadata_list)
                self._previous_batch_request_ids = current_batch_request_ids

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

    def _update_neuron_sampling_params(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Sync the on-device sampling config with the batch.

        Each sequence's top_k/top_p/temperature slot is written when it
        differs from the group's values. If any slot changed and
        transformers-neuronx is in use, the updated config is pushed to the
        model.
        """
        # Update Neuron sampling parameters (GenerationConfig in Neuron)
        current_sampling_params = self.model_config.neuron_sampling_params
        assert current_sampling_params is not None, (
            f"Failed to update sampling_params, "
            f"current sampling params is {current_sampling_params}")

        is_update_needed = False

        top_k = current_sampling_params.top_k
        top_p = current_sampling_params.top_p
        temperature = current_sampling_params.temperature

        # The index of a sequence's sampling parameters in neuron is equal to
        # its index in `input_block_ids`.
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params

            seq_group_top_k = sampling_params.top_k
            seq_group_top_p = sampling_params.top_p
            seq_group_temperature = sampling_params.temperature

            for seq_id in seq_ids:
                index = seq_group_metadata.block_tables[seq_id][0]
                if (top_k[index] != seq_group_top_k
                        or top_p[index] != seq_group_top_p
                        or temperature[index] != seq_group_temperature):
                    is_update_needed = True

                top_k[index] = seq_group_top_k
                top_p[index] = seq_group_top_p
                temperature[index] = seq_group_temperature

        # update_generation_config is only available in transformers-neuronx
        if is_update_needed and current_platform.use_transformers_neuronx():
            self.model.model.update_generation_config(current_sampling_params)

    def _convert_to_neuron_sampling_params(
            self, sampling_params: SamplingParams) -> Tuple[int, float, float]:
        """Return ``(top_k, top_p, temperature)`` adjusted for Neuron.

        A temperature of 0.0 maps to greedy sampling ``(1, 1.0, 1.0)``. A
        top_k below 1 or above ``_MAX_NEURON_SAMPLING_TOP_K`` is replaced by
        that maximum.
        """
        # Returns the top_k, top_p and temperature parameters for neuron.
        top_k = sampling_params.top_k
        top_p = sampling_params.top_p
        temperature = sampling_params.temperature

        if temperature == 0.0:
            # Enable greedy sampling on zero temperature
            return (1, 1.0, 1.0)
        if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
            top_k = self._MAX_NEURON_SAMPLING_TOP_K

        return (top_k, top_p, temperature)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and sample the next token.

        Returns:
            A single-element list holding the sampler output.

        Raises:
            ValueError: if ``num_steps > 1`` (multi-step is unsupported).
            RuntimeError: if neither supported Neuron framework is selected,
                so there is no model call to make.
        """
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        # extract top_k, top_p and temperature from model_input for neuron
        # forward call
        sampling_params = (torch.tensor([[
            seq_group.sampling_params.top_k, seq_group.sampling_params.top_p,
            seq_group.sampling_params.temperature
        ] for seq_group in model_input.sampling_metadata.seq_groups]))

        if current_platform.use_neuronx_distributed():
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                input_block_ids=model_input.input_block_ids,
                sampling_params=sampling_params,
                adapter_ids=model_input.adapter_ids,
                **MultiModalKwargs.as_kwargs(
                    model_input.multi_modal_kwargs or {},
                    device=self.device,
                ),
            )
        elif current_platform.use_transformers_neuronx():
            # [TODO] validate on-device sampling
            # The model signature may need change for on-device sampling
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                input_block_ids=model_input.input_block_ids,
                **MultiModalKwargs.as_kwargs(
                    model_input.multi_modal_kwargs or {},
                    device=self.device,
                ),
            )
        else:
            # Without this branch ``hidden_states`` would be unbound below
            # and surface as an opaque UnboundLocalError.
            raise RuntimeError(
                "No supported Neuron framework is available; expected "
                "transformers-neuronx or neuronx-distributed-inference.")

        # Compute the logits only if the on-device sampling is turned off as
        # on-device sampling outputs the token ids.
        if self._on_device_sampling_disabled:
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)
        else:
            logits = hidden_states

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()

    def process_multi_modal_data_neuron(self, mm_data):
        """Return ``mm_data`` unchanged."""
        # this is a no-op for NeuronModelRunner
        return mm_data

    def remove_all_loras(self):
        """Unsupported for this runner; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Unsupported for this runner; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def add_lora(self, lora_request: LoRARequest):
        """Unsupported for this runner; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def remove_lora(self, lora_id: int) -> bool:
        """Unsupported for this runner; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def pin_lora(self, lora_id: int) -> bool:
        """Unsupported for this runner; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")

    def list_loras(self) -> Set[int]:
        """Unsupported for this runner; always raises NotImplementedError."""
        raise NotImplementedError(
            "LoRAs are not supported for Transformers NeuronX framework")
vllm/worker/neuron_worker.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A Neuron worker class."""
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch.distributed
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.platforms.neuron
import
NeuronFramework
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.neuron_model_runner
import
NeuronModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
WorkerBase
,
WorkerInput
)
logger
=
init_logger
(
__name__
)
class NeuronWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes the model on a group of neuron cores.

    The concrete model runner is picked at construction time from the
    framework reported by ``current_platform.get_neuron_framework_to_use()``.
    """

    model_runner: NeuronModelRunner

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False) -> None:
        """Store rank/distribution settings and build the model runner.

        Raises:
            NotImplementedError: if no supported Neuron framework is
                selected (not installed, or an unknown framework name).
        """
        WorkerBase.__init__(self, vllm_config=vllm_config)
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        self.lora_config = vllm_config.lora_config

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        neuron_framework = current_platform.get_neuron_framework_to_use()
        if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX:
            self.model_runner = self.get_tnx_model_runner(vllm_config)
        elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE:
            self.model_runner = self.get_neuronx_distributed_model_runner(
                vllm_config)
        else:
            # Keep a space between the literal text and the requested name so
            # the message reads "Specified framework <name> is either ...".
            raise NotImplementedError(
                "Specified framework" +
                f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" +
                " is either not installed or not supported." +
                " Supported frameworks: " +
                "[transformers-neuronx, neuronx-distributed-inference]")

    def get_tnx_model_runner(self, vllm_config):
        """Build the transformers-neuronx model runner.

        LoRA must be disabled (asserted), and speculative decoding raises
        ``NotImplementedError``.
        """
        assert (self.lora_config
                is None), ("LoRA is not supported for TransformersNeuronX "
                           "framework.")
        if self.speculative_config is not None:
            raise NotImplementedError(
                "Speculative decoding is not supported for TransformersNeuronX"
            )
        return NeuronModelRunner(vllm_config=vllm_config)

    def get_neuronx_distributed_model_runner(self, vllm_config):
        """Build the neuronx-distributed-inference model runner.

        The runner module is imported lazily. Speculative decoding raises
        ``NotImplementedError``.
        """
        from vllm.worker.neuronx_distributed_model_runner import (
            NeuronxDistributedModelRunner)
        if self.speculative_config is not None:
            assert (self.lora_config
                    is None), ("LoRA is not supported for Speculative Decoding")
            raise NotImplementedError(
                "Speculative decoding is not supported for NeuronxDistributed")
        return NeuronxDistributedModelRunner(vllm_config=vllm_config)

    def init_device(self) -> None:
        """Initialise the distributed environment and seed the RNG."""
        self.init_distributed_environment()

        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Delegate model loading to the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        Swapping is not yet supported, so always return num_cpu_blocks=0.

        We configure num_gpu_blocks to be equal to max_num_seqs.
        """
        # Set the number of GPU blocks to be the same as the maximum number of
        # sequences that can be processed in a single batch. This is equivalent
        # to schedule without PagedAttention.
        num_gpu_blocks = self.scheduler_config.max_num_seqs + 1

        # Swap not yet supported with Neuron backend.
        num_cpu_blocks = 0

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache.

        Only the block counts produced by
        ``determine_num_available_blocks`` are accepted (asserted); they are
        recorded on ``cache_config``.
        """
        # Different values are not tested.
        assert num_cpu_blocks == 0
        assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    @property
    def do_metadata_broadcast(self) -> bool:
        """Always False for this worker."""
        return False

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """Always None for this worker."""
        return None

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Build a ``WorkerInput`` carrying only the sequence-group count."""
        return WorkerInput(num_seq_groups=len(
            execute_model_req.seq_group_metadata_list), )

    def execute_worker(self, worker_input: WorkerInput) -> None:
        """No-op for this worker."""
        pass

    def get_cache_block_size_bytes(self) -> int:
        """Determine the size in bytes of a cache block.

        This is required for speculative decoding; it is not yet implemented.
        """
        raise NotImplementedError

    def init_distributed_environment(self):
        """Neuron uses transformers-neuronx for tensor parallelism.

        vLLM still needs the environment initialized when TP/PP > 1
        """
        # Inside this method the bare name resolves to the module-level
        # function imported from vllm.distributed, not to this method.
        init_distributed_environment(
            world_size=1,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend=current_platform.dist_backend,
        )

        ensure_model_parallel_initialized(
            1,
            1,
        )

    def _ensure_lora_supported(self) -> None:
        """Raise NotImplementedError when transformers-neuronx is in use."""
        if current_platform.use_transformers_neuronx():
            raise NotImplementedError(
                f"{type(self)} does not support LoRA with Neuron Framework "
                f"Transformers NeuronX")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward to the model runner unless LoRA is unsupported."""
        self._ensure_lora_supported()
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward to the model runner unless LoRA is unsupported."""
        self._ensure_lora_supported()
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Forward to the model runner unless LoRA is unsupported."""
        self._ensure_lora_supported()
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Forward to the model runner unless LoRA is unsupported."""
        self._ensure_lora_supported()
        return self.model_runner.list_loras()
vllm/worker/neuronx_distributed_model_runner.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
,
Set
import
torch
from
neuronx_distributed_inference.models.mllama.aspect_ratio_utils
import
(
get_all_supported_aspect_ratios
)
from
neuronx_distributed_inference.modules.generation.sampling
import
(
prepare_sampling_params
)
from
neuronx_distributed_inference.modules.lora_serving
import
(
LoraCheckpoint
,
LoraServingConfig
)
from
vllm.config
import
VllmConfig
from
vllm.entrypoints.openai.serving_models
import
LoRAModulePath
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.neuronx_distributed
import
(
_get_model_architecture
,
get_neuron_model
)
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.worker.neuron_model_runner
import
(
ModelInputForNeuron
,
NeuronModelRunner
)
logger
=
init_logger
(
__name__
)
class
NeuronxDistributedModelRunner
(
NeuronModelRunner
):
def __init__(
    self,
    vllm_config: VllmConfig,
):
    """Run the base runner's initialisation, then reset LoRA/model state.

    The three attributes below start as ``None``; ``load_model`` assigns
    them.
    """
    super().__init__(vllm_config)
    # Placeholders that load_model assigns.
    self.lora_checkpoint = None
    self.model = None
    self.lora_serving_config = None
@
staticmethod
def
_get_lora_paths_strings
(
lora_modules
:
List
[
LoRAModulePath
]):
if
not
lora_modules
:
return
None
return
{
_
.
get
(
"name"
):
_
.
get
(
"path"
)
for
_
in
lora_modules
}
def
_get_nxdi_lora_config
(
self
):
override_neuron_config
=
self
.
model_config
.
override_neuron_config
lora_modules
=
override_neuron_config
.
pop
(
"lora_modules"
,
None
)
target_modules
=
override_neuron_config
.
pop
(
"target_modules"
,
None
)
lora_ckpt_paths
=
self
.
_get_lora_paths_strings
(
lora_modules
)
if
self
.
lora_config
.
max_loras
<
len
(
lora_ckpt_paths
):
raise
ValueError
(
"Number of LoRAs (%s) exceeds maximum "
"allowed (%s)"
,
len
(
lora_ckpt_paths
),
self
.
lora_config
.
max_loras
)
return
LoraServingConfig
(
max_loras
=
self
.
lora_config
.
max_loras
,
max_lora_rank
=
self
.
lora_config
.
max_lora_rank
,
target_modules
=
target_modules
,
lora_ckpt_paths
=
lora_ckpt_paths
,
)
def load_model(self) -> None:
    """Load the Neuron model, wiring in LoRA serving when it is configured.

    When ``self.lora_config`` is set, the NxDI LoRA serving config and a
    ``LoraCheckpoint`` built from it are stored on the runner first. The
    current ``self.lora_serving_config`` is then passed on to
    ``get_neuron_model`` and the result is stored in ``self.model``.
    """
    lora_enabled = self.lora_config is not None
    if lora_enabled:
        # Update LoRA config
        self.lora_serving_config = self._get_nxdi_lora_config()
        self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config)

    loader_kwargs = {
        "parallel_config": self.parallel_config,
        "scheduler_config": self.scheduler_config,
        "lora_serving_config": self.lora_serving_config,
    }
    self.model = get_neuron_model(self.model_config, **loader_kwargs)
def get_nxd_sampling_params(self, sampling_metadata):
    """Assemble per-slot sampling parameters for the NxD model call.

    Lists of length ``self.scheduler_config.max_num_seqs`` are created with
    the defaults ``top_k=1``, ``top_p=1.0`` and ``temperature=1.0``; the
    leading slots are then overwritten from
    ``sampling_metadata.seq_groups``. A non-positive ``top_k`` is replaced
    by the on-device sampling ``global_topk`` when an on-device sampling
    config is present, otherwise by the model's ``vocab_size``.

    Returns:
        The result of ``prepare_sampling_params`` for those lists.
    """
    neuron_config = self.model.config.neuron_config
    if neuron_config.on_device_sampling_config:
        max_topk = neuron_config.on_device_sampling_config.global_topk
    else:
        max_topk = self.model.config.vocab_size

    num_slots = self.scheduler_config.max_num_seqs
    top_k = [1] * num_slots
    top_p = [1.0] * num_slots
    temperature = [1.0] * num_slots

    for slot, seq_group in enumerate(sampling_metadata.seq_groups):
        requested = seq_group.sampling_params
        if requested.top_k > 0:
            top_k[slot] = requested.top_k
        else:
            top_k[slot] = max_topk
        top_p[slot] = requested.top_p
        temperature[slot] = requested.temperature

    return prepare_sampling_params(
        batch_size=self.scheduler_config.max_num_seqs,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature)
def get_multi_modal_data_neuron(self, input_images):
    """Placeholder for multi-modal input handling; not implemented here.

    Raises:
        NotImplementedError: Always.
    """
    reason = "need to restore multi-modal support"
    raise NotImplementedError(reason)
@torch.inference_mode()
def execute_model(
    self,
    model_input: ModelInputForNeuron,
    kv_caches: Optional[List[torch.Tensor]] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
    """Run one forward/sample step, with a dedicated path for Mllama.

    Any architecture other than ``"MllamaForConditionalGeneration"`` is
    delegated unchanged to the parent implementation. For Mllama the model
    is called with the image tensors from ``model_input.multi_modal_kwargs``
    when ``pixel_values`` is present, otherwise with placeholder tensors,
    and the result is passed to ``self.model.sample``.

    Returns:
        A single-element list holding the sampler output (Mllama path), or
        the parent's return value.

    Raises:
        ValueError: If ``num_steps`` is greater than 1.
    """
    if num_steps > 1:
        raise ValueError(
            "NeuronModelRunner does not support multi-step execution.")

    architecture = _get_model_architecture(self.model.config)
    if architecture != "MllamaForConditionalGeneration":
        return super().execute_model(model_input, kv_caches,
                                     intermediate_tensors, num_steps)

    sampling_params = self.get_nxd_sampling_params(
        model_input.sampling_metadata)

    mm_kwargs = model_input.multi_modal_kwargs
    if mm_kwargs.get('pixel_values') is not None:
        pixel_values = mm_kwargs.get('pixel_values')
        aspect_ratios = mm_kwargs.get('aspect_ratios')
        num_chunks = mm_kwargs.get('num_chunks')
        has_image = mm_kwargs.get('has_image').squeeze(1)
    else:
        # Text-only step: the model still expects vision inputs, so feed
        # placeholders sized to the batch (batch of 1 without tokens).
        if model_input.input_tokens is not None:
            bs = model_input.input_tokens.shape[0]
        else:
            bs = 1
        # NOTE(review): presumably (batch, images, tiles, channels, height,
        # width) -- confirm against the NxDI Mllama model.
        pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560],
                                   dtype=torch.bfloat16)
        aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64)
        num_chunks = torch.zeros((bs, 1), dtype=torch.int32)
        has_image = torch.zeros([bs], dtype=torch.int32)

    hidden_states = self.model(
        input_ids=model_input.input_tokens,
        positions=model_input.input_positions,
        seq_ids=model_input.input_block_ids,
        pixel_values=pixel_values,
        aspect_ratios=aspect_ratios,
        sampling_params=sampling_params,
        num_chunks=num_chunks,
        has_image=has_image,
    )

    sampled = self.model.sample(
        hidden_states=hidden_states,
        sampling_metadata=model_input.sampling_metadata,
    )
    return [sampled]
def process_multi_modal_data_neuron(self, mm_data):
    """Add the Neuron-specific multi-modal entries to ``mm_data`` in place.

    Three keys are written: ``"aspect_ratios"`` (looked up from
    ``"aspect_ratio_ids"``), ``"num_chunks"`` (copied from ``"num_tiles"``)
    and ``"has_image"`` (all ones when ``"pixel_values"`` is present and
    not entirely zero, all zeros otherwise).

    Returns:
        The same ``mm_data`` object that was passed in.
    """
    max_tiles = self.model.config.vision_config.max_num_tiles
    supported_ratios = get_all_supported_aspect_ratios(max_tiles)

    # Neuron uses aspect_ratios instead of aspect_ratio_ids
    ratio_ids = mm_data.get("aspect_ratio_ids")
    selected_ratios = torch.tensor(supported_ratios[ratio_ids])
    mm_data["aspect_ratios"] = selected_ratios.unsqueeze(0)

    # Neuron's num_chunks is HF's num_tiles
    mm_data["num_chunks"] = mm_data.get("num_tiles")

    # Input has an image if it has pixel_values
    bs = mm_data["num_chunks"].shape[0]
    pixel_values = mm_data.get("pixel_values")
    image_present = (pixel_values is not None
                     and not torch.all(pixel_values == 0))
    fill = torch.ones if image_present else torch.zeros
    mm_data["has_image"] = fill(bs)
    return mm_data
def
_get_lora_adapter_ids
(
self
,
seq_group_metadata_list
):
# set LoRA adapter IDs for multi-lora serving
batch_size
=
len
(
seq_group_metadata_list
)
if
self
.
lora_checkpoint
is
not
None
:
# "0" indicates NxDI to use the base model for inference
adapter_ids
=
[
"0"
]
*
batch_size
for
idx
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
if
seq_group_metadata
.
lora_request
is
not
None
:
adapter_ids
[
idx
]
=
seq_group_metadata
.
lora_request
.
lora_name
# convert adapter_ids from strings to integers
adapter_ids
=
self
.
lora_checkpoint
.
convert_adapter_ids_to_indices
(
adapter_ids
,
batch_size
)
else
:
adapter_ids
=
torch
.
zeros
((
batch_size
),
dtype
=
torch
.
int32
)
return
adapter_ids
def prepare_model_input(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    virtual_engine: int = 0,
    finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron:
    """Build the ``ModelInputForNeuron`` for one scheduling step.

    The first sequence group decides whether the batch is handled as
    prompts or as decodes. When on-device sampling is not disabled, each
    group's sampling params are rewritten in place with their Neuron
    equivalents. Multi-modal kwargs are batched from every group that
    carries multi-modal data; this replaces the value returned by
    ``_prepare_prompt``. ``virtual_engine`` is not used by this method.
    """
    # NOTE: We assume that all sequences in the group are all prompts or
    # all decodes.
    if seq_group_metadata_list[0].is_prompt:
        # Prepare input tensors.
        (input_tokens, input_positions, input_block_ids, seq_lens,
         multi_modal_kwargs) = self._prepare_prompt(seq_group_metadata_list)
    else:
        (input_tokens, input_positions,
         input_block_ids) = self._prepare_decode(seq_group_metadata_list)
        seq_lens = None

    if not self._on_device_sampling_disabled:
        for group in seq_group_metadata_list:
            params = group.sampling_params
            # Unpacking happens before any attribute is written, and the
            # targets are then assigned left to right.
            (params.top_k, params.top_p, params.temperature) = (
                self._convert_to_neuron_sampling_params(params))

    # we need multi_modal_data for later tokens as well
    multi_modal_kwargs_list: List[MultiModalKwargs] = list(
        filter(None, (group.multi_modal_data
                      for group in seq_group_metadata_list)))
    multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

    lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)

    generators = self.get_generators(finished_requests_ids)
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        # query_lens is not needed if chunked prefill is not
        # supported. Since neuron worker doesn't support chunked prefill
        # just use seq_lens instead.
        seq_lens,
        self.device,
        self.pin_memory,
        generators=generators)

    return ModelInputForNeuron(
        input_tokens=input_tokens,
        input_positions=input_positions,
        input_block_ids=input_block_ids,
        sampling_metadata=sampling_metadata,
        multi_modal_kwargs=multi_modal_kwargs,
        adapter_ids=lora_adapter_ids)
def remove_all_loras(self):
    """Unsupported: LoRAs are fixed via ``override_neuron_config``.

    Raises:
        NotImplementedError: Always.
    """
    message = ("Managing LoRAs is only supported through the "
               "lora_modules parameter in override_neuron_config")
    raise NotImplementedError(message)
def set_active_loras(self, lora_requests: Set[LoRARequest],
                     lora_mapping: LoRAMapping) -> None:
    """Unsupported: LoRAs are fixed via ``override_neuron_config``.

    Raises:
        NotImplementedError: Always; both arguments are ignored.
    """
    message = ("Managing LoRAs is only supported through the "
               "lora_modules parameter in override_neuron_config")
    raise NotImplementedError(message)
def add_lora(self, lora_request: LoRARequest):
    """Log that adding a LoRA at runtime is ignored by this runner.

    No adapter is loaded here; the method only emits a warning that names
    the ignored request and then returns ``None``.

    Args:
        lora_request: The request being ignored; it is included in the
            warning text.
    """
    # The request is passed as a lazy logging argument, so the format
    # string needs a matching %s placeholder; without one the logging
    # module reports a formatting error instead of emitting the warning.
    logger.warning(
        "Adding LoRAs is only supported through the "
        "lora_modules parameter in override_neuron_config. If you supplied "
        "the parameter, you can ignore this warning. Ignoring "
        "lora request: %s", lora_request)
def remove_lora(self, lora_id: int) -> bool:
    """Unsupported: LoRAs are fixed via ``override_neuron_config``.

    Raises:
        NotImplementedError: Always; ``lora_id`` is ignored.
    """
    message = ("Managing LoRAs is only supported through the "
               "lora_modules parameter in override_neuron_config")
    raise NotImplementedError(message)
def pin_lora(self, lora_id: int) -> bool:
    """Unsupported: LoRAs are fixed via ``override_neuron_config``.

    Raises:
        NotImplementedError: Always; ``lora_id`` is ignored.
    """
    message = ("Managing LoRAs is only supported through the "
               "lora_modules parameter in override_neuron_config")
    raise NotImplementedError(message)
def list_loras(self) -> Set[int]:
    """Unsupported: LoRAs are fixed via ``override_neuron_config``.

    Raises:
        NotImplementedError: Always.
    """
    message = ("Managing LoRAs is only supported through the "
               "lora_modules parameter in override_neuron_config")
    raise NotImplementedError(message)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment