Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
972eddf7
Unverified
Commit
972eddf7
authored
May 29, 2025
by
Satyajith Chilappagari
Committed by
GitHub
May 29, 2025
Browse files
[Neuron] Add multi-LoRA support for Neuron. (#18284)
Signed-off-by:
Satyajith Chilappagari
<
satchill@amazon.com
>
parent
fd7bb88d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
343 additions
and
26 deletions
+343
-26
tests/neuron/2_core/test_multi_lora.py
tests/neuron/2_core/test_multi_lora.py
+98
-0
vllm/model_executor/model_loader/neuronx_distributed.py
vllm/model_executor/model_loader/neuronx_distributed.py
+18
-13
vllm/platforms/neuron.py
vllm/platforms/neuron.py
+0
-3
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+31
-1
vllm/worker/neuron_worker.py
vllm/worker/neuron_worker.py
+38
-4
vllm/worker/neuronx_distributed_model_runner.py
vllm/worker/neuronx_distributed_model_runner.py
+158
-5
No files found.
tests/neuron/2_core/test_multi_lora.py
0 → 100644
View file @
972eddf7
# SPDX-License-Identifier: Apache-2.0
from
huggingface_hub
import
snapshot_download
from
vllm
import
LLM
,
SamplingParams
from
vllm.lora.request
import
LoRARequest
def
test_llama_single_lora
():
sql_lora_files
=
snapshot_download
(
repo_id
=
"yard1/llama-2-7b-sql-lora-test"
)
llm
=
LLM
(
model
=
"meta-llama/Llama-2-7b-hf"
,
tensor_parallel_size
=
2
,
max_num_seqs
=
4
,
max_model_len
=
512
,
use_v2_block_manager
=
True
,
override_neuron_config
=
{
"sequence_parallel_enabled"
:
False
,
"skip_warmup"
:
True
,
"lora_modules"
:
[{
"name"
:
"lora_id_1"
,
"path"
:
sql_lora_files
}]
},
enable_lora
=
True
,
max_loras
=
1
,
max_lora_rank
=
256
,
device
=
"neuron"
)
"""For multi-lora requests using NxDI as the backend, only the lora_name
needs to be specified. The lora_id and lora_path are supplied at the LLM
class/server initialization, after which the paths are handled by NxDI"""
lora_req_1
=
LoRARequest
(
"lora_id_1"
,
0
,
" "
)
prompts
=
[
"The president of the United States is"
,
"The capital of France is"
,
]
outputs
=
llm
.
generate
(
prompts
,
SamplingParams
(
top_k
=
1
),
lora_request
=
[
lora_req_1
,
lora_req_1
])
expected_outputs
=
[
" the head of state and head of government of the United States. "
"The president direct"
,
" a city of contrasts. The city is home to the Eiffel Tower"
]
for
expected_output
,
output
in
zip
(
expected_outputs
,
outputs
):
generated_text
=
output
.
outputs
[
0
].
text
assert
(
expected_output
==
generated_text
)
def
test_llama_multiple_lora
():
sql_lora_files
=
snapshot_download
(
repo_id
=
"yard1/llama-2-7b-sql-lora-test"
)
llm
=
LLM
(
model
=
"meta-llama/Llama-2-7b-hf"
,
tensor_parallel_size
=
2
,
max_num_seqs
=
4
,
max_model_len
=
512
,
use_v2_block_manager
=
True
,
override_neuron_config
=
{
"sequence_parallel_enabled"
:
False
,
"skip_warmup"
:
True
,
"lora_modules"
:
[{
"name"
:
"lora_id_1"
,
"path"
:
sql_lora_files
},
{
"name"
:
"lora_id_2"
,
"path"
:
sql_lora_files
}]
},
enable_lora
=
True
,
max_loras
=
2
,
max_lora_rank
=
256
,
device
=
"neuron"
)
"""For multi-lora requests using NxDI as the backend, only the lora_name
needs to be specified. The lora_id and lora_path are supplied at the LLM
class/server initialization, after which the paths are handled by NxDI"""
lora_req_1
=
LoRARequest
(
"lora_id_1"
,
0
,
" "
)
lora_req_2
=
LoRARequest
(
"lora_id_2"
,
1
,
" "
)
prompts
=
[
"The president of the United States is"
,
"The capital of France is"
,
]
outputs
=
llm
.
generate
(
prompts
,
SamplingParams
(
top_k
=
1
),
lora_request
=
[
lora_req_1
,
lora_req_2
])
expected_outputs
=
[
" the head of state and head of government of the United States. "
"The president direct"
,
" a city of contrasts. The city is home to the Eiffel Tower"
]
for
expected_output
,
output
in
zip
(
expected_outputs
,
outputs
):
generated_text
=
output
.
outputs
[
0
].
text
assert
(
expected_output
==
generated_text
)
vllm/model_executor/model_loader/neuronx_distributed.py
View file @
972eddf7
...
@@ -17,6 +17,8 @@ from neuronx_distributed_inference.models.config import (
...
@@ -17,6 +17,8 @@ from neuronx_distributed_inference.models.config import (
FusedSpecNeuronConfig
,
OnDeviceSamplingConfig
)
FusedSpecNeuronConfig
,
OnDeviceSamplingConfig
)
from
neuronx_distributed_inference.models.mllama.utils
import
(
from
neuronx_distributed_inference.models.mllama.utils
import
(
create_vision_mask
)
create_vision_mask
)
from
neuronx_distributed_inference.modules.lora_serving
import
(
LoraServingConfig
)
from
neuronx_distributed_inference.utils.hf_adapter
import
(
from
neuronx_distributed_inference.utils.hf_adapter
import
(
load_pretrained_config
)
load_pretrained_config
)
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
PretrainedConfig
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
PretrainedConfig
...
@@ -80,25 +82,26 @@ class NeuronCausalLM(nn.Module):
...
@@ -80,25 +82,26 @@ class NeuronCausalLM(nn.Module):
# Lazy initialized
# Lazy initialized
self
.
model
:
nn
.
Module
self
.
model
:
nn
.
Module
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_id
s
:
torch
.
Tensor
,
position
s
:
torch
.
Tensor
,
position
s
:
torch
.
Tensor
,
input_block_id
s
:
torch
.
Tensor
,
input_block_id
s
:
torch
.
Tensor
,
sampling_param
s
:
torch
.
Tensor
,
sampling_params
:
torch
.
Tensor
,
prev_hidden
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
adapter_ids
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# sort block ids sequentially for perf/neuron support reasons
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids
,
sorted_indices
=
torch
.
sort
(
input_block_ids
)
sorted_input_block_ids
,
sorted_indices
=
torch
.
sort
(
input_block_ids
)
input_ids
=
torch
.
index_select
(
input_ids
,
0
,
sorted_indices
)
input_ids
=
torch
.
index_select
(
input_ids
,
0
,
sorted_indices
)
positions
=
torch
.
index_select
(
positions
,
0
,
sorted_indices
)
positions
=
torch
.
index_select
(
positions
,
0
,
sorted_indices
)
sampling_params
=
torch
.
index_select
(
sampling_params
,
0
,
sampling_params
=
torch
.
index_select
(
sampling_params
,
0
,
sorted_indices
)
sorted_indices
)
output
=
self
.
model
(
input_ids
,
output
=
self
.
model
(
input_ids
,
attention_mask
=
None
,
attention_mask
=
None
,
position_ids
=
positions
,
position_ids
=
positions
,
seq_ids
=
sorted_input_block_ids
,
seq_ids
=
sorted_input_block_ids
,
sampling_params
=
sampling_params
)
sampling_params
=
sampling_params
,
prev_hidden
=
prev_hidden
,
adapter_ids
=
adapter_ids
)
# on-device sampling
# on-device sampling
if
self
.
config
.
neuron_config
.
on_device_sampling_config
:
if
self
.
config
.
neuron_config
.
on_device_sampling_config
:
output
=
output
.
hidden_states
output
=
output
.
hidden_states
...
@@ -522,7 +525,8 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
...
@@ -522,7 +525,8 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
def
_get_default_neuron_config
(
model_config
:
ModelConfig
,
def
_get_default_neuron_config
(
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
):
scheduler_config
:
SchedulerConfig
,
lora_serving_config
:
LoraServingConfig
):
"""Generate a neuron config based on vllm config args."""
"""Generate a neuron config based on vllm config args."""
on_device_sampling_config
=
OnDeviceSamplingConfig
(
dynamic
=
True
,
on_device_sampling_config
=
OnDeviceSamplingConfig
(
dynamic
=
True
,
deterministic
=
False
)
deterministic
=
False
)
...
@@ -541,7 +545,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
...
@@ -541,7 +545,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
padding_side
=
"right"
,
padding_side
=
"right"
,
on_device_sampling_config
=
on_device_sampling_config
,
on_device_sampling_config
=
on_device_sampling_config
,
sequence_parallel_enabled
=
True
,
sequence_parallel_enabled
=
True
,
)
lora_serving_config
=
lora_serving_config
)
return
neuron_config
return
neuron_config
...
@@ -581,7 +585,8 @@ def _get_neuron_config_after_override(default_neuron_config,
...
@@ -581,7 +585,8 @@ def _get_neuron_config_after_override(default_neuron_config,
def
get_neuron_model
(
model_config
:
ModelConfig
,
def
get_neuron_model
(
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
scheduler_config
:
SchedulerConfig
,
lora_serving_config
:
LoraServingConfig
)
->
nn
.
Module
:
"""Initializes a neuron-optimized model for inference."""
"""Initializes a neuron-optimized model for inference."""
model_arch
=
_get_model_architecture
(
model_config
.
hf_config
)
model_arch
=
_get_model_architecture
(
model_config
.
hf_config
)
if
model_arch
==
"MllamaForConditionalGeneration"
:
if
model_arch
==
"MllamaForConditionalGeneration"
:
...
@@ -589,7 +594,7 @@ def get_neuron_model(model_config: ModelConfig,
...
@@ -589,7 +594,7 @@ def get_neuron_model(model_config: ModelConfig,
else
:
else
:
model
=
NeuronCausalLM
(
model_config
.
hf_config
)
model
=
NeuronCausalLM
(
model_config
.
hf_config
)
default_neuron_config_args
=
_get_default_neuron_config
(
default_neuron_config_args
=
_get_default_neuron_config
(
model_config
,
parallel_config
,
scheduler_config
)
model_config
,
parallel_config
,
scheduler_config
,
lora_serving_config
)
neuron_config
=
_get_neuron_config_after_override
(
neuron_config
=
_get_neuron_config_after_override
(
default_neuron_config_args
,
model_config
.
override_neuron_config
)
default_neuron_config_args
,
model_config
.
override_neuron_config
)
...
...
vllm/platforms/neuron.py
View file @
972eddf7
...
@@ -49,9 +49,6 @@ class NeuronPlatform(Platform):
...
@@ -49,9 +49,6 @@ class NeuronPlatform(Platform):
if
parallel_config
.
world_size
>
1
:
if
parallel_config
.
world_size
>
1
:
parallel_config
.
distributed_executor_backend
=
"uni"
parallel_config
.
distributed_executor_backend
=
"uni"
assert
(
vllm_config
.
lora_config
is
None
),
"LoRA is not supported for Neuron backend."
if
vllm_config
.
cache_config
and
vllm_config
.
model_config
:
if
vllm_config
.
cache_config
and
vllm_config
.
model_config
:
# neuron needs block_size = max_model_len
# neuron needs block_size = max_model_len
vllm_config
.
cache_config
.
block_size
=
\
vllm_config
.
cache_config
.
block_size
=
\
...
...
vllm/worker/neuron_model_runner.py
View file @
972eddf7
...
@@ -2,13 +2,15 @@
...
@@ -2,13 +2,15 @@
import
os
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
from
vllm.model_executor.model_loader.neuron
import
get_neuron_model
...
@@ -36,6 +38,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
...
@@ -36,6 +38,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
input_block_ids
:
Optional
[
torch
.
Tensor
]
=
None
input_block_ids
:
Optional
[
torch
.
Tensor
]
=
None
sampling_metadata
:
SamplingMetadata
=
None
sampling_metadata
:
SamplingMetadata
=
None
multi_modal_kwargs
:
BatchedTensorInputs
=
None
multi_modal_kwargs
:
BatchedTensorInputs
=
None
adapter_ids
:
Optional
[
str
]
=
None
def
as_broadcastable_tensor_dict
(
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
...
@@ -80,6 +83,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -80,6 +83,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
"The model will run without sliding window."
)
"The model will run without sliding window."
)
self
.
device_config
=
(
self
.
device_config
if
self
.
device_config
self
.
device_config
=
(
self
.
device_config
if
self
.
device_config
is
not
None
else
DeviceConfig
())
is
not
None
else
DeviceConfig
())
self
.
lora_config
=
vllm_config
.
lora_config
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
...
@@ -378,6 +382,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -378,6 +382,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
input_block_ids
=
model_input
.
input_block_ids
,
input_block_ids
=
model_input
.
input_block_ids
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
adapter_ids
=
model_input
.
adapter_ids
,
**
MultiModalKwargs
.
as_kwargs
(
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
dtype
=
self
.
model_config
.
dtype
,
...
@@ -416,3 +421,28 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -416,3 +421,28 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
@
property
@
property
def
vocab_size
(
self
)
->
int
:
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
return
self
.
model_config
.
get_vocab_size
()
def
remove_all_loras
(
self
):
raise
NotImplementedError
(
"LoRAs are not supported for Transformers NeuronX framework"
)
def
set_active_loras
(
self
,
lora_requests
:
Set
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
raise
NotImplementedError
(
"LoRAs are not supported for Transformers NeuronX framework"
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
):
raise
NotImplementedError
(
"LoRAs are not supported for Transformers NeuronX framework"
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
(
"LoRAs are not supported for Transformers NeuronX framework"
)
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
(
"LoRAs are not supported for Transformers NeuronX framework"
)
def
list_loras
(
self
)
->
Set
[
int
]:
raise
NotImplementedError
(
"LoRAs are not supported for Transformers NeuronX framework"
)
vllm/worker/neuron_worker.py
View file @
972eddf7
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""A Neuron worker class."""
"""A Neuron worker class."""
import
os
import
os
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch.distributed
import
torch.distributed
...
@@ -9,19 +9,19 @@ from vllm.config import VllmConfig
...
@@ -9,19 +9,19 @@ from vllm.config import VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.neuron
import
NeuronFramework
from
vllm.platforms.neuron
import
NeuronFramework
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.neuron_model_runner
import
NeuronModelRunner
from
vllm.worker.neuron_model_runner
import
NeuronModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
WorkerBase
,
LoRANotSupportedWorkerBase
,
WorkerBase
,
WorkerInput
)
WorkerInput
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
NeuronWorker
(
LoRANotSupportedWorkerBase
,
LocalOrDistributedWorkerBase
):
class
NeuronWorker
(
LocalOrDistributedWorkerBase
):
"""A worker class that executes the model on a group of neuron cores.
"""A worker class that executes the model on a group of neuron cores.
"""
"""
...
@@ -38,6 +38,7 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -38,6 +38,7 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
rank
=
rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
distributed_init_method
=
distributed_init_method
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
self
.
lora_config
=
vllm_config
.
lora_config
if
self
.
model_config
.
trust_remote_code
:
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
# note: lazy import to avoid importing torch before initializing
...
@@ -59,6 +60,9 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -59,6 +60,9 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"[transformers-neuronx, neuronx-distributed-inference]"
)
"[transformers-neuronx, neuronx-distributed-inference]"
)
def
get_tnx_model_runner
(
self
,
vllm_config
):
def
get_tnx_model_runner
(
self
,
vllm_config
):
assert
(
self
.
lora_config
is
None
),
(
"LoRA is not supported for TransformersNeuronX "
"framework."
)
from
vllm.worker.multi_step_neuron_model_runner
import
(
from
vllm.worker.multi_step_neuron_model_runner
import
(
MultiStepNeuronModelRunner
)
MultiStepNeuronModelRunner
)
if
self
.
speculative_config
is
not
None
:
if
self
.
speculative_config
is
not
None
:
...
@@ -72,6 +76,8 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -72,6 +76,8 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
from
vllm.worker.neuronx_distributed_model_runner
import
(
from
vllm.worker.neuronx_distributed_model_runner
import
(
NeuronxDistributedModelRunner
)
NeuronxDistributedModelRunner
)
if
self
.
speculative_config
is
not
None
:
if
self
.
speculative_config
is
not
None
:
assert
(
self
.
lora_config
is
None
),
"LoRA is not supported for Speculative Decoding"
return
MultiStepNeuronxDistributedModelRunner
(
return
MultiStepNeuronxDistributedModelRunner
(
vllm_config
=
vllm_config
)
vllm_config
=
vllm_config
)
else
:
else
:
...
@@ -156,3 +162,31 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -156,3 +162,31 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
1
,
1
,
1
,
1
,
)
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
if
current_platform
.
use_transformers_neuronx
():
raise
NotImplementedError
(
f
"
{
type
(
self
)
}
does not support LoRA with Neuron Framework "
f
"Transformers NeuronX"
)
return
self
.
model_runner
.
add_lora
(
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
if
current_platform
.
use_transformers_neuronx
():
raise
NotImplementedError
(
f
"
{
type
(
self
)
}
does not support LoRA with Neuron Framework "
f
"Transformers NeuronX"
)
return
self
.
model_runner
.
remove_lora
(
lora_id
)
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
if
current_platform
.
use_transformers_neuronx
():
raise
NotImplementedError
(
f
"
{
type
(
self
)
}
does not support LoRA with Neuron Framework "
f
"Transformers NeuronX"
)
return
self
.
model_runner
.
pin_lora
(
lora_id
)
def
list_loras
(
self
)
->
Set
[
int
]:
if
current_platform
.
use_transformers_neuronx
():
raise
NotImplementedError
(
f
"
{
type
(
self
)
}
does not support LoRA with Neuron Framework "
f
"Transformers NeuronX"
)
return
self
.
model_runner
.
list_loras
()
vllm/worker/neuronx_distributed_model_runner.py
View file @
972eddf7
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Set
import
torch
import
torch
from
neuronx_distributed_inference.modules.generation.sampling
import
(
from
neuronx_distributed_inference.modules.generation.sampling
import
(
prepare_sampling_params
)
prepare_sampling_params
)
from
neuronx_distributed_inference.modules.lora_serving
import
(
LoraCheckpoint
,
LoraServingConfig
)
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.entrypoints.openai.serving_models
import
LoRAModulePath
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.neuronx_distributed
import
(
from
vllm.model_executor.model_loader.neuronx_distributed
import
(
_get_model_architecture
,
get_neuron_model
)
_get_model_architecture
,
get_neuron_model
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.worker.neuron_model_runner
import
(
ModelInputForNeuron
,
from
vllm.worker.neuron_model_runner
import
(
ModelInputForNeuron
,
NeuronModelRunner
)
NeuronModelRunner
)
...
@@ -25,11 +32,44 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
...
@@ -25,11 +32,44 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
):
):
super
().
__init__
(
vllm_config
)
super
().
__init__
(
vllm_config
)
self
.
lora_checkpoint
=
None
self
.
model
=
None
self
.
lora_serving_config
=
None
@
staticmethod
def
_get_lora_paths_strings
(
lora_modules
:
List
[
LoRAModulePath
]):
if
not
lora_modules
:
return
None
return
{
_
.
get
(
"name"
):
_
.
get
(
"path"
)
for
_
in
lora_modules
}
def
_get_nxdi_lora_config
(
self
):
override_neuron_config
=
self
.
model_config
.
override_neuron_config
lora_modules
=
override_neuron_config
.
pop
(
"lora_modules"
,
None
)
target_modules
=
override_neuron_config
.
pop
(
"target_modules"
,
None
)
lora_ckpt_paths
=
self
.
_get_lora_paths_strings
(
lora_modules
)
if
self
.
lora_config
.
max_loras
<
len
(
lora_ckpt_paths
):
raise
ValueError
(
"Number of LoRAs (%s) exceeds maximum "
"allowed (%s)"
,
len
(
lora_ckpt_paths
),
self
.
lora_config
.
max_loras
)
return
LoraServingConfig
(
max_loras
=
self
.
lora_config
.
max_loras
,
max_lora_rank
=
self
.
lora_config
.
max_lora_rank
,
target_modules
=
target_modules
,
lora_ckpt_paths
=
lora_ckpt_paths
,
)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_neuron_model
(
self
.
model_config
,
# Update LoRA config
parallel_config
=
self
.
parallel_config
,
if
self
.
lora_config
is
not
None
:
scheduler_config
=
self
.
scheduler_config
)
self
.
lora_serving_config
=
self
.
_get_nxdi_lora_config
()
self
.
lora_checkpoint
=
LoraCheckpoint
(
self
.
lora_serving_config
)
self
.
model
=
get_neuron_model
(
self
.
model_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
lora_serving_config
=
self
.
lora_serving_config
)
def
get_nxd_sampling_params
(
self
,
sampling_metadata
):
def
get_nxd_sampling_params
(
self
,
sampling_metadata
):
if
self
.
model
.
config
.
neuron_config
.
on_device_sampling_config
:
if
self
.
model
.
config
.
neuron_config
.
on_device_sampling_config
:
...
@@ -134,3 +174,116 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
...
@@ -134,3 +174,116 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
)
)
return
[
output
]
return
[
output
]
def
_get_lora_adapter_ids
(
self
,
seq_group_metadata_list
):
# set LoRA adapter IDs for multi-lora serving
batch_size
=
len
(
seq_group_metadata_list
)
if
self
.
lora_checkpoint
is
not
None
:
# "0" indicates NxDI to use the base model for inference
adapter_ids
=
[
"0"
]
*
batch_size
for
idx
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
if
seq_group_metadata
.
lora_request
is
not
None
:
adapter_ids
[
idx
]
=
seq_group_metadata
.
lora_request
.
lora_name
# convert adapter_ids from strings to integers
adapter_ids
=
self
.
lora_checkpoint
.
convert_adapter_ids_to_indices
(
adapter_ids
,
batch_size
)
else
:
adapter_ids
=
torch
.
zeros
((
batch_size
),
dtype
=
torch
.
int32
)
return
adapter_ids
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForNeuron
:
multi_modal_kwargs
=
None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
input_block_ids
,
seq_lens
,
multi_modal_kwargs
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
seq_lens
=
None
if
not
self
.
_on_device_sampling_disabled
:
for
seq_group_metadata
in
seq_group_metadata_list
:
sampling_params
=
seq_group_metadata
.
sampling_params
top_k
,
top_p
,
temperature
=
(
self
.
_convert_to_neuron_sampling_params
(
sampling_params
))
sampling_params
.
top_k
=
top_k
sampling_params
.
top_p
=
top_p
sampling_params
.
temperature
=
temperature
lora_adapter_ids
=
self
.
_get_lora_adapter_ids
(
seq_group_metadata_list
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens
,
self
.
device
,
self
.
pin_memory
,
generators
=
self
.
get_generators
(
finished_requests_ids
))
if
current_platform
.
use_transformers_neuronx
(
)
and
not
self
.
_on_device_sampling_disabled
:
# Once the request IDs are changed in current iteration, we will
# update the on-device sampling parameters.
current_batch_request_ids
=
[
seq_group_meta_data
.
request_id
for
seq_group_meta_data
in
seq_group_metadata_list
]
if
current_batch_request_ids
!=
self
.
_previous_batch_request_ids
:
self
.
_update_neuron_sampling_params
(
seq_group_metadata_list
)
self
.
_previous_batch_request_ids
=
current_batch_request_ids
return
ModelInputForNeuron
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
input_block_ids
=
input_block_ids
,
sampling_metadata
=
sampling_metadata
,
multi_modal_kwargs
=
multi_modal_kwargs
,
adapter_ids
=
lora_adapter_ids
)
def
remove_all_loras
(
self
):
raise
NotImplementedError
(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config"
)
def
set_active_loras
(
self
,
lora_requests
:
Set
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
raise
NotImplementedError
(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config"
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
):
logger
.
warning
(
"Adding LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config. If you supplied "
"the parameter, you can ignore this warning. Ignoring"
"lora request: "
,
lora_request
)
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config"
)
def
pin_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config"
)
def
list_loras
(
self
)
->
Set
[
int
]:
raise
NotImplementedError
(
"Managing LoRAs is only supported through the "
"lora_modules parameter in override_neuron_config"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment