sglang / Commits / 021f76e4

Unverified commit 021f76e4, authored Jun 11, 2025 by Lifu Huang; committed by GitHub on Jun 11, 2025.
[Perf] Refactor LoRAManager to eliminate stream syncs and redundant computations (#6994)
Parent: 777688b8

Showing 2 changed files with 83 additions and 39 deletions:

    python/sglang/srt/lora/lora_manager.py    +79  -34
    python/sglang/srt/lora/mem_pool.py         +4   -5
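Taken together, the diff below does three things: it introduces a transfer_adapter_info helper that assembles adapter metadata (weight indices, LoRA ranks, scalings) on the host, stages it in pinned memory, and copies it to the device with non-blocking copies; it fills the CUDA-graph seg_lens/seg_indptr buffers once at initialization and derives max_len from the CPU-side extend_seq_lens_cpu instead of a GPU reduction; and it simplifies buffer-slot eviction in the LoRA memory pool so that get_available_buffer_slot handles the uid bookkeeping itself.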
python/sglang/srt/lora/lora_manager.py
@@ -81,7 +81,7 @@ class LoRAManager:
             seg_indptr=torch.zeros(self.max_bs_in_cuda_graph + 1, dtype=torch.int32),
-            max_len=0,
+            max_len=1,
             weight_indices=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
@@ -89,6 +89,17 @@ class LoRAManager:
             scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
         )
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
+            dim=0,
+            out=self.cuda_graph_batch_info.seg_indptr[
+                1 : self.max_bs_in_cuda_graph + 1
+            ],
+        )

     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
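The hunk above moves the seg_lens/seg_indptr setup for CUDA-graph batches into initialization: graph-captured decode batches always have segment length 1, so the cumulative offsets never change across forward passes. A minimal illustration of that precomputation, using an assumed max_bs_in_cuda_graph of 4 (not sglang code, just the same tensor operations):

    import torch

    max_bs_in_cuda_graph = 4  # assumed value for illustration
    seg_lens = torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32)
    seg_indptr = torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32)

    # Every decode segment has length 1, so the prefix sums are 0, 1, 2, ..., max_bs.
    seg_lens[:max_bs_in_cuda_graph].fill_(1)
    torch.cumsum(
        seg_lens[:max_bs_in_cuda_graph],
        dim=0,
        out=seg_indptr[1 : max_bs_in_cuda_graph + 1],
    )
    print(seg_indptr.tolist())  # [0, 1, 2, 3, 4]

Because these buffers are written once here, the per-batch path only has to update bs, max_len, and the adapter metadata.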
@@ -159,6 +170,45 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size

+        def transfer_adapter_info(
+            weight_indices_out: torch.Tensor,
+            lora_ranks_out: torch.Tensor,
+            scalings_out: torch.Tensor,
+        ):
+            """
+            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
+            to device (CUDA) asynchronously.
+            """
+            weight_indices = [0] * len(forward_batch.lora_paths)
+            lora_ranks = [0] * self.max_loras_per_batch
+            scalings = [0] * self.max_loras_per_batch
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+
+            # Use pinned memory to avoid synchronizations during host-to-device transfer
+            weight_indices_tensor = torch.tensor(
+                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            lora_ranks_tensor = torch.tensor(
+                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            scalings_tensor = torch.tensor(
+                scalings, dtype=torch.float, pin_memory=True, device="cpu"
+            )
+
+            # Copy to device tensors asynchronously
+            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
+            lora_ranks_out[: self.max_loras_per_batch].copy_(
+                lora_ranks_tensor, non_blocking=True
+            )
+            scalings_out[: self.max_loras_per_batch].copy_(
+                scalings_tensor, non_blocking=True
+            )
+
         if (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
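transfer_adapter_info stages the metadata in pinned (page-locked) host memory before issuing non-blocking copies; with a pageable staging tensor, copy_ would have to block the stream. A self-contained sketch of the same pattern, with a hypothetical helper name async_upload (not part of sglang):

    import torch

    def async_upload(values, out_gpu):
        # Pinned host memory allows a true asynchronous host-to-device copy;
        # a pageable staging tensor would force the copy to synchronize.
        staging = torch.tensor(values, dtype=out_gpu.dtype, pin_memory=True, device="cpu")
        out_gpu[: len(values)].copy_(staging, non_blocking=True)

    if torch.cuda.is_available():
        weight_indices_gpu = torch.empty(8, dtype=torch.int32, device="cuda")
        async_upload([0, 1, 0, 2], weight_indices_gpu)
        # The copy is enqueued on the current stream, so later kernels launched on
        # that stream see the updated values without an explicit synchronization.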
@@ -166,51 +216,46 @@ class LoRAManager:
         ):
             # Do in-place updates when CUDA graph is enabled and the batch forward mode
             # could use CUDA graph.
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[:bs],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
-            )
-            self.cuda_graph_batch_info.max_len = 1
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                self.cuda_graph_batch_info.weight_indices[i] = (
-                    self.memory_pool.get_buffer_id(lora_path)
-                )
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    self.cuda_graph_batch_info.lora_ranks[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.config.hf_config["r"]
-                    self.cuda_graph_batch_info.scalings[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.scaling
+            transfer_adapter_info(
+                self.cuda_graph_batch_info.weight_indices,
+                self.cuda_graph_batch_info.lora_ranks,
+                self.cuda_graph_batch_info.scalings,
+            )
+
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.max_len = 1
             batch_info = self.cuda_graph_batch_info
         else:
+            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+            )
+            transfer_adapter_info(
+                weight_indices,
+                lora_ranks,
+                scalings,
+            )
+
             seg_lens = (
                 forward_batch.extend_seq_lens
                 if forward_batch.forward_mode.is_extend()
                 else torch.ones(bs, device=self.device)
             )
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+
             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-            max_len = int(torch.max(seg_lens))
-            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-            )
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                    scalings[weight_indices[i]] = lora.scaling
+
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
                 ...
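In the non-graph path, max_len is now taken from forward_batch.extend_seq_lens_cpu rather than reduced from the GPU tensor. The difference matters because converting a GPU scalar to a Python int forces a device-to-host synchronization, while max() over a list that already lives on the host involves no GPU traffic. Illustrative comparison with made-up lengths (not sglang code):

    import torch

    extend_seq_lens_cpu = [5, 3, 9, 1]  # host-side copy, assumed values
    max_len_new = max(extend_seq_lens_cpu)  # pure host work, no GPU sync

    if torch.cuda.is_available():
        extend_seq_lens = torch.tensor(extend_seq_lens_cpu, device="cuda")
        # Old pattern: int() on a GPU reduction blocks until the device catches up.
        max_len_old = int(torch.max(extend_seq_lens))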
python/sglang/srt/lora/mem_pool.py
@@ -132,12 +132,13 @@ class LoRAMemoryPool:
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == "":
-                    return buffer_id, ""
+                    return buffer_id

             for buffer_id in range(self.max_loras_per_batch):
                 # Evict unneeded lora
                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    return buffer_id, self.buffer_id_to_uid[buffer_id]
+                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                    return buffer_id

             raise ValueError(
                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
                 ...
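The allocation policy above is two passes over the slots: first reuse an empty slot, otherwise evict a slot whose adapter is not needed by the current batch, dropping its reverse mapping in the same place. A standalone sketch with plain Python containers standing in for the pool's state (names are illustrative, not the sglang API):

    def get_available_buffer_slot(buffer_id_to_uid, uid_to_buffer_id, cur_uids):
        # Pass 1: prioritize empty slots.
        for buffer_id, uid in enumerate(buffer_id_to_uid):
            if uid == "":
                return buffer_id
        # Pass 2: evict an adapter the current batch does not use.
        for buffer_id, uid in enumerate(buffer_id_to_uid):
            if uid not in cur_uids:
                uid_to_buffer_id.pop(uid)
                return buffer_id
        raise ValueError(
            "No available buffer slots found. Please ensure the number of active "
            "loras is less than max_loras_per_batch."
        )

    slots = ["lora-a", "", "lora-b"]
    reverse = {"lora-a": 0, "lora-b": 2}
    # Slot 1 is empty, so it is returned before any eviction is considered.
    print(get_available_buffer_slot(slots, reverse, cur_uids={"lora-a", "lora-b"}))  # 1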
@@ -145,9 +146,7 @@ class LoRAMemoryPool:
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
-                buffer_id, evicted_lora_uid = get_available_buffer_slot()
-                if evicted_lora_uid != "":
-                    self.uid_to_buffer_id.pop(evicted_lora_uid)
+                buffer_id = get_available_buffer_slot()
                 self.load_lora_weight_to_buffer(
                     uid, buffer_id, lora_adapters.get(uid, None)
                 )
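The second hunk is the caller-side effect of that change: get_available_buffer_slot now returns only the buffer index and already pops the evicted uid from uid_to_buffer_id, so the loading loop no longer needs to unpack an evicted_lora_uid or clean up the mapping itself before calling load_lora_weight_to_buffer.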