sglang · Commit e8e18dcd (unverified)
Authored May 12, 2025 by Lianmin Zheng; committed by GitHub on May 12, 2025

Revert "fix some typos" (#6244)
Parent: bad7c26f
Changes: 95 files in total; the first 20 are shown on this page.

Showing 20 changed files with 50 additions and 50 deletions (+50 -50)
python/sglang/srt/layers/attention/base_attn_backend.py  (+3 -3)
python/sglang/srt/layers/attention/double_sparsity_backend.py  (+1 -1)
python/sglang/srt/layers/attention/flashinfer_backend.py  (+2 -2)
python/sglang/srt/layers/attention/flashinfer_mla_backend.py  (+1 -1)
python/sglang/srt/layers/attention/triton_backend.py  (+1 -1)
python/sglang/srt/layers/attention/vision.py  (+1 -1)
python/sglang/srt/layers/dp_attention.py  (+1 -1)
python/sglang/srt/layers/logits_processor.py  (+1 -1)
python/sglang/srt/layers/quantization/__init__.py  (+5 -5)
python/sglang/srt/layers/quantization/blockwise_int8.py  (+1 -1)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py  (+3 -3)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py  (+1 -1)
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py  (+3 -3)
python/sglang/srt/layers/quantization/fp8_utils.py  (+2 -2)
python/sglang/srt/layers/rotary_embedding.py  (+2 -2)
python/sglang/srt/lora/backend/base_backend.py  (+8 -8)
python/sglang/srt/lora/backend/flashinfer_backend.py  (+1 -1)
python/sglang/srt/lora/layers.py  (+1 -1)
python/sglang/srt/lora/lora.py  (+4 -4)
python/sglang/srt/lora/lora_manager.py  (+8 -8)
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -20,7 +20,7 @@ class AttentionBackend(ABC):
         raise NotImplementedError()
 
     def init_cuda_graph_state(self, max_bs: int):
-        """Init the global shared states for CUDA graph."""
+        """Init the global shared states for cuda graph."""
         raise NotImplementedError()
 
     def init_forward_metadata_capture_cuda_graph(
@@ -33,7 +33,7 @@ class AttentionBackend(ABC):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        """Init the metadata for a forward pass for capturing a CUDA graph."""
+        """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
 
     def init_forward_metadata_replay_cuda_graph(
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        """Init the metadata for a forward pass for replaying a CUDA graph."""
+        """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()
 
     def get_cuda_graph_seq_len_fill_value(self):
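The three hunks above only flip "CUDA" back to "cuda" in docstrings, but the hooks they touch are the backend's CUDA-graph lifecycle (allocate fixed state, capture, replay). For readers unfamiliar with the pattern, here is a minimal sketch in plain PyTorch, not sglang code; the buffer names and the toy forward function are invented for illustration:

```python
# Sketch of the capture/replay pattern behind the AttentionBackend hooks:
# all inputs live in fixed buffers sized for the largest batch, the graph is
# captured once, and each replay only sees what was copied into those buffers.
import torch

assert torch.cuda.is_available()  # CUDA graphs require a CUDA device

max_bs, hidden = 8, 64
# Static buffers sized for the largest batch (cf. init_cuda_graph_state(max_bs)).
static_x = torch.zeros(max_bs, hidden, device="cuda")
static_out = torch.zeros(max_bs, hidden, device="cuda")
weight = torch.randn(hidden, hidden, device="cuda")

def forward():
    # Stand-in for the real attention forward pass.
    static_out.copy_(static_x @ weight)

# Warm up on a side stream before capture (required by CUDA graphs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    forward()
torch.cuda.current_stream().wait_stream(s)

# Capture once (cf. init_forward_metadata_capture_cuda_graph).
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    forward()

# Replay with new data: write into the static buffers, then replay
# (cf. init_forward_metadata_replay_cuda_graph). Unused batch slots keep a
# harmless fill value (cf. get_cuda_graph_seq_len_fill_value).
real_bs = 3
static_x.zero_()
static_x[:real_bs].copy_(torch.randn(real_bs, hidden, device="cuda"))
g.replay()
result = static_out[:real_bs].clone()
```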
python/sglang/srt/layers/attention/double_sparsity_backend.py

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 
 class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
-        # Lazy import to avoid the initialization of CUDA context
+        # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
             extend_attention_fwd,
             flash_decode_attention_fwd,
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -664,7 +664,7 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr = kv_indptr[: bs + 1]
 
             if wrapper.is_cuda_graph_enabled:
-                # Directly write to the CUDA graph input buffer
+                # Directly write to the cuda graph input buffer
                 kv_indices = wrapper._paged_kv_indices_buf
             else:
                 kv_indices = torch.empty(
@@ -1173,7 +1173,7 @@ def fast_decode_plan(
     """
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
    Modifications:
-    - Remove unnecessary device-to-device copy for the CUDA graph buffers.
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
     - Remove unnecessary host-to-device copy for the metadata buffers.
     """
     batch_size = len(last_page_len)
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -874,7 +874,7 @@ def fast_mla_decode_plan(
 ) -> None:
     """A faster version of BatchMLAPagedAttentionWrapper::plan,
     for skipping the stream synchronization in original plan function during
-    CUDA graph replaying.
+    cuda graph replaying.
     """
     self._causal = causal
     self._page_size = page_size
python/sglang/srt/layers/attention/triton_backend.py

@@ -92,7 +92,7 @@ class TritonAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
-        # Lazy import to avoid the initialization of CUDA context
+        # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.decode_attention import (
             decode_attention_fwd,
         )
python/sglang/srt/layers/attention/vision.py

@@ -257,7 +257,7 @@ class VisionFlash3Attention(nn.Module):
         **kwargs,
     ):
         if not _is_cuda:
-            raise Exception("VisionFlash3Attention is only available for CUDA")
+            raise Exception("VisionFlash3Attention is only available for cuda")
         super().__init__()
 
     def forward(
python/sglang/srt/layers/dp_attention.py

@@ -237,7 +237,7 @@ def dp_scatter(
     forward_batch: ForwardBatch,
 ):
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
-    # since local_tokens may be padded for CUDA graph
+    # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
 
     local_tokens.fill_(0)
python/sglang/srt/layers/logits_processor.py

@@ -166,7 +166,7 @@ class LogitsMetadata:
     def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
         if self.global_num_tokens_for_logprob_cpu is None:
-            # we are capturing CUDA graph
+            # we are capturing cuda graph
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
python/sglang/srt/layers/quantization/__init__.py

@@ -38,7 +38,7 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-    # Define empty classes as placeholders when vLLM is not available
+    # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
         def override_quantization_method(self, *args, **kwargs):
             return None
@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vLLM by `pip install vllm==0.8.4`"
+            "Please install vllm by `pip install vllm==0.8.4`"
         )
 
     return QUANTIZATION_METHODS[quantization]
@@ -231,7 +231,7 @@ original_isinstance = builtins.isinstance
 def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     """
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
-    can recognize SGLang layers
+    can recognize sglang layers
     """
     if not VLLM_AVAILABLE:
         return
@@ -267,7 +267,7 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     """
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
-    Convert SGLang arguments to vLLM arguments.
+    Convert sglang arguments to vllm arguments.
     """
     original_apply = class_obj.apply
     sig = inspect.signature(original_apply)
@@ -329,6 +329,6 @@ def monkey_patch_quant_configs():
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
 
-# Only apply monkey patches if vLLM is available
+# Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
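The docstring of monkey_patch_isinstance_for_vllm_base_layer (touched above) describes patching isinstance so that vllm's get_quant_method also accepts sglang layer classes. A minimal, hypothetical sketch of that technique follows; the class names and the _EQUIVALENTS table are invented for illustration and are not the actual sglang/vllm types or mapping:

```python
# Temporarily replace builtins.isinstance so a third-party library's type
# checks also accept "equivalent" local classes.
import builtins

original_isinstance = builtins.isinstance

class TheirLinear:   # stand-in for a library-side layer class
    pass

class OurLinear:     # stand-in for the corresponding local layer class
    pass

# Treat OurLinear instances as TheirLinear while the patch is active.
_EQUIVALENTS = {TheirLinear: (OurLinear,)}

def monkey_patch_isinstance(reverse: bool = False):
    if reverse:
        builtins.isinstance = original_isinstance
        return

    def patched_isinstance(obj, classinfo):
        if original_isinstance(obj, classinfo):
            return True
        extra = _EQUIVALENTS.get(classinfo)
        return bool(extra) and original_isinstance(obj, extra)

    builtins.isinstance = patched_isinstance

monkey_patch_isinstance()
assert isinstance(OurLinear(), TheirLinear)        # now passes the library's check
monkey_patch_isinstance(reverse=True)
assert not isinstance(OurLinear(), TheirLinear)    # patch undone
```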
python/sglang/srt/layers/quantization/blockwise_int8.py

@@ -208,7 +208,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
-        # Use torch Parameter to avoid CUDA graph capturing issue
+        # Use torch Parameter to avoid cuda graph capturing issue
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         layer.weight_scale_inv = torch.nn.Parameter(
             layer.weight_scale_inv.data, requires_grad=False
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -363,7 +363,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vLLM"
+                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
                 )
             if (
                 self.quant_format == CompressionFormat.marlin_24.value
@@ -409,7 +409,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_fp8_w8a16(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsW8A16Fp8, please install vLLM"
+                    "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
                 )
             is_static_input_scheme = input_quant and not input_quant.dynamic
             return CompressedTensorsW8A16Fp8(
@@ -491,7 +491,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         ):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensors24, please install vLLM"
+                    "vllm is not installed, to use CompressedTensors24, please install vllm"
                 )
             # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -65,7 +65,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsWNA16MoEMethod, please install vLLM."
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py

@@ -27,10 +27,10 @@ except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
     def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vLLM is not installed")
+        raise ImportError("vllm is not installed")
 
     def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vLLM is not installed")
+        raise ImportError("vllm is not installed")
 
 
 __all__ = ["CompressedTensorsW8A16Fp8"]
@@ -45,7 +45,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
         if not MARLIN_FP8_AVAILABLE:
             raise ImportError(
-                "vLLM is not installed. To use CompressedTensorsW8A16Fp8, please install vLLM"
+                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
             )
 
     @classmethod
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -357,7 +357,7 @@ def apply_fp8_linear(
         # Fused GEMM_DQ
         if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-            # Fall back to vLLM cutlass w8a8 fp8 kernel
+            # Fall back to vllm cutlass w8a8 fp8 kernel
             output = ops.cutlass_scaled_mm(
                 qinput,
                 weight,
@@ -493,7 +493,7 @@ def apply_fp8_linear(
     if cutlass_fp8_supported:
         try:
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vLLM cutlass w8a8 fp8 kernel
+                # Fall back to vllm cutlass w8a8 fp8 kernel
                 output = ops.cutlass_scaled_mm(
                     qinput,
                     weight,
python/sglang/srt/layers/rotary_embedding.py

@@ -186,8 +186,8 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
     It supports multiple scaling factors. Since multiple LoRA adapters may have
     different scaling factors, we need multiple cos/sin caches. In this way,
-    instead of running rotary embedding kernel per LoRA adapter, we can run multiple
-    LoRA adapters in a batched way.
+    instead of running rotary embedding kernel per lora, we can run multiple
+    lora in a batched way.
     In addition to that, we also keep the cos/sin cache for the scaling factor
     of 1 (default) at all times.
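The docstring above explains the design: one cos/sin cache per linear scaling factor, kept side by side, so tokens from LoRA adapters with different factors can be served by a single batched lookup instead of one rotary kernel launch per adapter. A small sketch of that idea with assumed shapes and helper names, not sglang's actual implementation:

```python
# Build one cos/sin cache per linear scaling factor and stack them, so a
# single batched gather can serve tokens that use different factors.
import torch

def build_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

def build_caches(dim: int, max_pos: int, scaling_factors: list[float]):
    """Returns stacked cos/sin caches of shape (num_factors, max_pos, dim // 2)."""
    inv_freq = build_inv_freq(dim)
    cos_list, sin_list = [], []
    for factor in scaling_factors:
        # Linear scaling: positions are divided by the scaling factor.
        t = torch.arange(max_pos).float() / factor
        freqs = torch.outer(t, inv_freq)
        cos_list.append(freqs.cos())
        sin_list.append(freqs.sin())
    return torch.stack(cos_list), torch.stack(sin_list)

# One batch containing tokens from adapters with different scaling factors.
scaling_factors = [1.0, 2.0, 4.0]  # factor 1.0 (default) is always kept
cos_cache, sin_cache = build_caches(dim=64, max_pos=4096, scaling_factors=scaling_factors)

positions = torch.tensor([5, 17, 3, 250])   # token positions in the batch
factor_ids = torch.tensor([0, 0, 2, 1])     # which scaling factor each token uses
cos = cos_cache[factor_ids, positions]      # (num_tokens, dim // 2), one gather
sin = sin_cache[factor_ids, positions]
```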
python/sglang/srt/lora/backend/base_backend.py

@@ -41,13 +41,13 @@ class BaseLoRABackend:
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
-        """Run segment Gemm of LoRA a modules with current backend.
+        """Run segment Gemm of lora a modules with current backend.
         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
 
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-            weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
-                     here r is LoRA rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
+            weights: a set of lora weights with shape (num_lora, c * r, input_dim),
+                     here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
                      usually input_dim is much larger than r
 
         Returns:
             result with shape (s, c * r)
@@ -57,12 +57,12 @@ class BaseLoRABackend:
     def run_lora_b_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
-        """Run segment Gemm of LoRA b modules with current backend.
+        """Run segment Gemm of lora b modules with current backend.
         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
 
         Args:
-            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is LoRA rank
-            weights: a set of LoRA weights with shape (num_lora, output_dim, r)
+            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+            weights: a set of lora weights with shape (num_lora, output_dim, r)
             usually output_dim is much larger than r
 
         Returns:
             result with shape (s, output_dim)
@@ -77,7 +77,7 @@ class BaseLoRABackend:
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        """Run the LoRA pass for QKV Layer.
+        """Run the lora pass for QKV Layer.
 
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
@@ -100,7 +100,7 @@ class BaseLoRABackend:
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        """Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
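The docstrings above pin down the segment GEMM contract for the LoRA A pass: x is (s, input_dim) with s the sum of all sequence lengths, weights is (num_lora, c * r, input_dim), and the result is (s, c * r). A naive, hypothetical reference loop that matches those shapes is shown below; the seg_lens/weight_indices arguments are assumptions for illustration, and real backends fuse this into one triton/flashinfer kernel:

```python
# Each request's rows are multiplied by the weights of the LoRA adapter
# assigned to that request. This loop only makes the shapes concrete.
import torch

def segment_gemm_lora_a(
    x: torch.Tensor,               # (s, input_dim), s = sum of all sequence lengths
    weights: torch.Tensor,         # (num_lora, c * r, input_dim)
    seg_lens: torch.Tensor,        # (num_segments,), tokens per request
    weight_indices: torch.Tensor,  # (num_segments,), adapter index per request
) -> torch.Tensor:                 # (s, c * r)
    out = x.new_empty(x.shape[0], weights.shape[1])
    start = 0
    for seg_len, w_idx in zip(seg_lens.tolist(), weight_indices.tolist()):
        end = start + seg_len
        out[start:end] = x[start:end] @ weights[w_idx].T
        start = end
    return out

# Example: two requests (lengths 3 and 2) using adapters 1 and 0.
s, input_dim, r, c, num_lora = 5, 16, 4, 3, 2
x = torch.randn(s, input_dim)
weights = torch.randn(num_lora, c * r, input_dim)
out = segment_gemm_lora_a(x, weights, torch.tensor([3, 2]), torch.tensor([1, 0]))
assert out.shape == (s, c * r)
```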
python/sglang/srt/lora/backend/flashinfer_backend.py

@@ -117,7 +117,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
             dtype=x.dtype,
         )
 
-        # Compute LoRA for gate and up proj respectively
+        # Compute lora for gate and up proj respectively
         lora_output[:, :output_dim] = self.run_lora_b_sgemm(
             x=lora_a_output[:, :lora_rank].contiguous(),
             weights=gate_up_lora_b[0],
python/sglang/srt/lora/layers.py

@@ -198,7 +198,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         if self.lora_backend.fuse_stacked_lora_b:
             assert (
                 B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The LoRA rank of q and kv should be the same when enabling fusion of qkv lora_b"
+            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
 
             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
python/sglang/srt/lora/lora.py

@@ -40,7 +40,7 @@ class LoRALayer(nn.Module):
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config
 
-        # LoRA weights in cpu. The weights are loaded from checkpoint.
+        # lora weights in cpu. The weights are loaded from checkpoint.
         self.weights: Dict[str, torch.Tensor] = {}
@@ -97,7 +97,7 @@ class LoRAAdapter(nn.Module):
     def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
-        # Collect target q/k/v modules. This process is necessary since there might be no LoRA attached to k_proj
+        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
         target_module = set()
         for weight_name in weight_names:
             if "k_proj" in weight_name:
@@ -110,7 +110,7 @@ class LoRAAdapter(nn.Module):
             return
 
         for weight_name in weight_names:
-            # We assume every LoRA adaptor should contain LoRA modules for q_proj
+            # We assume every lora adaptor should contain lora modules for q_proj
             if "q_proj" in weight_name:
                 q_name = weight_name
                 k_name = weight_name.replace("q_proj", "k_proj")
@@ -118,7 +118,7 @@ class LoRAAdapter(nn.Module):
                 kv_name = weight_name.replace("q_proj", "kv_proj")
                 qkv_name = weight_name.replace("q_proj", "qkv_proj")
 
-                # If k_proj doesn't have LoRA, initialize it to zero
+                # If k_proj doesn't have lora, initialize it to zero
                 k_proj_weight = (
                     weights[k_name]
                     if "k_proj" in target_module
python/sglang/srt/lora/lora_manager.py

@@ -93,14 +93,14 @@ class LoRAManager:
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
-        # Target module names in HuggingFace LoRA configs.
+        # Target module names in huggingface lora configs.
         # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
         self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
             self.hf_target_names.update(self.configs[name].target_modules)
 
-        # Target LoRA weight names for lora_a and lora_b modules respectively.
+        # Target lora weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]
@@ -119,11 +119,11 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter
 
-        # misc LoRA configs
+        # misc lora configs
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
 
         if self.lora_backend == "flashinfer":
-            # FIXME: remove the restrictions after supporting multi-rank for flashinfer backend
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
             max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
             scaling = list(self.loras.values())[0].scaling
             assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
@@ -144,16 +144,16 @@ class LoRAManager:
             self.lora_modules,
         )
 
-        # Initialize target LoRA modules in memory pool
+        # Initialize target lora modules in memory pool
         self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-        # load active LoRAs into LoRA memory pool
+        # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        # set up batch info shared by all LoRA modules
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
         if (
@@ -221,7 +221,7 @@ class LoRAManager:
         )
         self.lora_backend.set_batch_info(batch_info)
 
-        # call set_lora_info for each LoRA modules
+        # call set_lora_info for each lora modules
         for layer_id, modules in self.lora_modules.items():
             for module_name, module in modules:
                 if "qkv_proj" in module_name: