Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a9b53dd4
Unverified
Commit
a9b53dd4
authored
Jan 25, 2026
by
Woosuk Kwon
Committed by
GitHub
Jan 25, 2026
Browse files
[Model Runner V2] Add LoRAState to consolidate lora logic (#33062)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
254db42e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
42 deletions
+53
-42
vllm/v1/worker/gpu/lora_utils.py
vllm/v1/worker/gpu/lora_utils.py
+47
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+6
-2
vllm/v1/worker/gpu/states.py
vllm/v1/worker/gpu/states.py
+0
-40
No files found.
vllm/v1/worker/gpu/lora_utils.py
0 → 100644
View file @
a9b53dd4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
from
vllm.lora.request
import
LoRARequest
# LoRA id stored in a request slot that has no adapter attached.
NO_LORA_ID = 0


class LoraState:
    """Per-request LoRA bookkeeping.

    Two pieces of state are kept:

    * ``lora_ids``: an int32 NumPy array with one entry per request slot,
      holding the ``lora_int_id`` of the adapter assigned to that slot, or
      ``NO_LORA_ID`` when the slot has none.
    * ``lora_requests``: a dict from request id to its ``LoRARequest``; only
      requests that actually carry an adapter appear in it.
    """

    def __init__(self, max_num_reqs: int):
        """Allocate ``max_num_reqs`` slots, all initialised to ``NO_LORA_ID``."""
        self.lora_ids = np.full(max_num_reqs, NO_LORA_ID, dtype=np.int32)
        # req_id -> lora_request
        self.lora_requests: dict[str, LoRARequest] = {}

    def add_request(
        self,
        req_id: str,
        req_index: int,
        lora_request: LoRARequest | None,
    ) -> None:
        """Record the adapter (or the absence of one) for a request.

        Args:
            req_id: Key under which ``lora_request`` is stored.
            req_index: Index into ``lora_ids`` to write.
            lora_request: Adapter to attach, or ``None`` for no LoRA.
        """
        if lora_request is None:
            self.lora_ids[req_index] = NO_LORA_ID
            # Drop any entry left over from an earlier registration of the
            # same req_id; otherwise make_lora_inputs() would keep reporting
            # that adapter as active while the id mapping says "no LoRA".
            self.lora_requests.pop(req_id, None)
            return

        self.lora_requests[req_id] = lora_request
        self.lora_ids[req_index] = lora_request.lora_int_id

    def remove_request(self, req_id: str) -> None:
        """Forget the adapter stored for ``req_id``; unknown ids are ignored.

        The ``lora_ids`` slot is left untouched here; it is rewritten by the
        next ``add_request`` call that targets the same index.
        """
        self.lora_requests.pop(req_id, None)

    def make_lora_inputs(
        self,
        req_ids: list[str],
        idx_mapping: np.ndarray,
        num_scheduled_tokens: np.ndarray,
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """Build the LoRA mappings for a batch.

        Args:
            req_ids: Request ids in the batch, looked up in ``lora_requests``.
            idx_mapping: Integer array used to index ``lora_ids``
                (presumably batch position -> request slot; confirm with the
                caller).
            num_scheduled_tokens: Repeat count per entry of ``idx_mapping``,
                passed to ``ndarray.repeat``.

        Returns:
            A 3-tuple ``(prompt_lora_mapping, token_lora_mapping,
            active_lora_requests)``: the LoRA id of each selected slot, the
            same ids with each one repeated ``num_scheduled_tokens[i]`` times,
            and the set of ``LoRARequest`` objects registered for ``req_ids``.
        """
        batch_lora_ids = self.lora_ids[idx_mapping]
        # tolist() yields plain Python ints, matching the annotated
        # tuple[int, ...] return type (tuple(ndarray) would yield np.int32).
        prompt_lora_mapping = tuple(batch_lora_ids.tolist())
        token_lora_mapping = tuple(
            batch_lora_ids.repeat(num_scheduled_tokens).tolist()
        )

        active_lora_requests: set[LoRARequest] = {
            request
            for req_id in req_ids
            if (request := self.lora_requests.get(req_id)) is not None
        }
        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
vllm/v1/worker/gpu/model_runner.py
View file @
a9b53dd4
...
@@ -51,6 +51,7 @@ from vllm.v1.worker.gpu.kv_connector import (
...
@@ -51,6 +51,7 @@ from vllm.v1.worker.gpu.kv_connector import (
KVConnector
,
KVConnector
,
get_kv_connector
,
get_kv_connector
,
)
)
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
...
@@ -168,6 +169,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -168,6 +169,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_logits
=
self
.
max_num_reqs
*
(
self
.
num_speculative_steps
+
1
),
max_num_logits
=
self
.
max_num_reqs
*
(
self
.
num_speculative_steps
+
1
),
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
)
)
# LoRA-related workers.
self
.
lora_state
=
LoraState
(
max_num_reqs
=
self
.
max_num_reqs
)
# Buffers for CPU-to-GPU copies.
# Buffers for CPU-to-GPU copies.
self
.
tmp_idx_mapping
=
UvaBufferPool
(
self
.
max_num_reqs
,
torch
.
int32
)
self
.
tmp_idx_mapping
=
UvaBufferPool
(
self
.
max_num_reqs
,
torch
.
int32
)
...
@@ -426,6 +429,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -426,6 +429,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
self
.
encoder_runner
.
remove_request
(
req_id
)
self
.
encoder_runner
.
remove_request
(
req_id
)
self
.
prompt_logprobs_worker
.
remove_request
(
req_id
)
self
.
prompt_logprobs_worker
.
remove_request
(
req_id
)
self
.
lora_state
.
remove_request
(
req_id
)
def
free_states
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
def
free_states
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
:
...
@@ -444,7 +448,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -444,7 +448,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prompt_len
=
prompt_len
,
prompt_len
=
prompt_len
,
prefill_token_ids
=
new_req_data
.
prefill_token_ids
,
prefill_token_ids
=
new_req_data
.
prefill_token_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
lora_request
=
new_req_data
.
lora_request
,
)
)
req_index
=
self
.
req_states
.
req_id_to_index
[
req_id
]
req_index
=
self
.
req_states
.
req_id_to_index
[
req_id
]
...
@@ -469,6 +472,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -469,6 +472,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
prompt_logprobs_worker
.
add_request
(
self
.
prompt_logprobs_worker
.
add_request
(
req_id
,
req_index
,
new_req_data
.
sampling_params
req_id
,
req_index
,
new_req_data
.
sampling_params
)
)
self
.
lora_state
.
add_request
(
req_id
,
req_index
,
new_req_data
.
lora_request
)
if
scheduler_output
.
scheduled_new_reqs
:
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
apply_staged_writes
()
self
.
req_states
.
apply_staged_writes
()
...
@@ -841,7 +845,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -841,7 +845,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
if
self
.
lora_config
:
if
self
.
lora_config
:
# Activate LoRA adapters.
# Activate LoRA adapters.
lora_inputs
=
self
.
req
_state
s
.
make_lora_inputs
(
lora_inputs
=
self
.
lora
_state
.
make_lora_inputs
(
input_batch
.
req_ids
,
input_batch
.
req_ids
,
input_batch
.
idx_mapping_np
,
input_batch
.
idx_mapping_np
,
input_batch
.
num_scheduled_tokens
,
input_batch
.
num_scheduled_tokens
,
...
...
vllm/v1/worker/gpu/states.py
View file @
a9b53dd4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
from
vllm.v1.worker.gpu.buffer_utils
import
StagedWriteTensor
,
UvaBackedTensor
NO_LORA_ID
=
0
class
RequestState
:
class
RequestState
:
def
__init__
(
def
__init__
(
...
@@ -31,7 +26,6 @@ class RequestState:
...
@@ -31,7 +26,6 @@ class RequestState:
self
.
req_id_to_index
:
dict
[
str
,
int
]
=
{}
self
.
req_id_to_index
:
dict
[
str
,
int
]
=
{}
self
.
index_to_req_id
:
dict
[
int
,
str
]
=
{}
self
.
index_to_req_id
:
dict
[
int
,
str
]
=
{}
self
.
free_indices
=
list
(
range
(
max_num_reqs
))
self
.
free_indices
=
list
(
range
(
max_num_reqs
))
self
.
extra_data
:
dict
[
str
,
ExtraData
]
=
{}
self
.
prompt_len
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
prompt_len
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
...
@@ -70,10 +64,6 @@ class RequestState:
...
@@ -70,10 +64,6 @@ class RequestState:
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
self
.
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
)
# LoRA.
self
.
lora_ids
=
np
.
zeros
(
self
.
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
lora_ids
.
fill
(
NO_LORA_ID
)
@
property
@
property
def
num_reqs
(
self
)
->
int
:
def
num_reqs
(
self
)
->
int
:
return
len
(
self
.
req_id_to_index
)
return
len
(
self
.
req_id_to_index
)
...
@@ -84,13 +74,11 @@ class RequestState:
...
@@ -84,13 +74,11 @@ class RequestState:
prompt_len
:
int
,
prompt_len
:
int
,
prefill_token_ids
:
list
[
int
],
prefill_token_ids
:
list
[
int
],
num_computed_tokens
:
int
,
num_computed_tokens
:
int
,
lora_request
:
LoRARequest
|
None
,
)
->
None
:
)
->
None
:
assert
len
(
self
.
free_indices
)
>
0
,
"No free indices"
assert
len
(
self
.
free_indices
)
>
0
,
"No free indices"
req_idx
=
self
.
free_indices
.
pop
()
req_idx
=
self
.
free_indices
.
pop
()
self
.
req_id_to_index
[
req_id
]
=
req_idx
self
.
req_id_to_index
[
req_id
]
=
req_idx
self
.
index_to_req_id
[
req_idx
]
=
req_id
self
.
index_to_req_id
[
req_idx
]
=
req_id
self
.
extra_data
[
req_id
]
=
ExtraData
(
lora_request
)
self
.
prompt_len
[
req_idx
]
=
prompt_len
self
.
prompt_len
[
req_idx
]
=
prompt_len
prefill_len
=
len
(
prefill_token_ids
)
prefill_len
=
len
(
prefill_token_ids
)
...
@@ -102,43 +90,15 @@ class RequestState:
...
@@ -102,43 +90,15 @@ class RequestState:
self
.
num_computed_prefill_tokens
[
req_idx
]
=
num_computed_tokens
self
.
num_computed_prefill_tokens
[
req_idx
]
=
num_computed_tokens
self
.
num_computed_tokens
.
stage_write_elem
(
req_idx
,
num_computed_tokens
)
self
.
num_computed_tokens
.
stage_write_elem
(
req_idx
,
num_computed_tokens
)
if
lora_request
is
not
None
:
self
.
lora_ids
[
req_idx
]
=
lora_request
.
lora_int_id
else
:
self
.
lora_ids
[
req_idx
]
=
NO_LORA_ID
def
apply_staged_writes
(
self
)
->
None
:
def
apply_staged_writes
(
self
)
->
None
:
self
.
prefill_len
.
copy_to_uva
()
self
.
prefill_len
.
copy_to_uva
()
self
.
prefill_token_ids
.
apply_write
()
self
.
prefill_token_ids
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
self
.
num_computed_tokens
.
apply_write
()
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
def
remove_request
(
self
,
req_id
:
str
)
->
None
:
self
.
extra_data
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
req_idx
=
self
.
req_id_to_index
.
pop
(
req_id
,
None
)
if
req_idx
is
None
:
if
req_idx
is
None
:
# Request not found.
# Request not found.
return
return
self
.
index_to_req_id
.
pop
(
req_idx
,
None
)
self
.
index_to_req_id
.
pop
(
req_idx
,
None
)
self
.
free_indices
.
append
(
req_idx
)
self
.
free_indices
.
append
(
req_idx
)
def
make_lora_inputs
(
self
,
req_ids
:
list
[
str
],
idx_mapping
:
np
.
ndarray
,
num_scheduled_tokens
:
np
.
ndarray
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
set
[
LoRARequest
]]:
lora_ids
=
self
.
lora_ids
[
idx_mapping
]
prompt_lora_mapping
=
tuple
(
lora_ids
)
token_lora_mapping
=
tuple
(
lora_ids
.
repeat
(
num_scheduled_tokens
))
active_lora_requests
:
set
[
LoRARequest
]
=
set
()
for
req_id
in
req_ids
:
lora_request
=
self
.
extra_data
[
req_id
].
lora_request
if
lora_request
is
not
None
:
active_lora_requests
.
add
(
lora_request
)
return
prompt_lora_mapping
,
token_lora_mapping
,
active_lora_requests
@
dataclass
class
ExtraData
:
lora_request
:
LoRARequest
|
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment