sglang · Commit d738ab52 (Unverified)

fix some typos (#6209)

Authored May 12, 2025 by applesaucethebun; committed by GitHub on May 13, 2025.
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Parent: 3ee40ff9
Changes: 95
Showing 20 changed files with 50 additions and 50 deletions (+50, -50)

python/sglang/srt/layers/attention/base_attn_backend.py  +3 -3
python/sglang/srt/layers/attention/double_sparsity_backend.py  +1 -1
python/sglang/srt/layers/attention/flashinfer_backend.py  +2 -2
python/sglang/srt/layers/attention/flashinfer_mla_backend.py  +1 -1
python/sglang/srt/layers/attention/triton_backend.py  +1 -1
python/sglang/srt/layers/attention/vision.py  +1 -1
python/sglang/srt/layers/dp_attention.py  +1 -1
python/sglang/srt/layers/logits_processor.py  +1 -1
python/sglang/srt/layers/quantization/__init__.py  +5 -5
python/sglang/srt/layers/quantization/blockwise_int8.py  +1 -1
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py  +3 -3
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py  +1 -1
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py  +3 -3
python/sglang/srt/layers/quantization/fp8_utils.py  +2 -2
python/sglang/srt/layers/rotary_embedding.py  +2 -2
python/sglang/srt/lora/backend/base_backend.py  +8 -8
python/sglang/srt/lora/backend/flashinfer_backend.py  +1 -1
python/sglang/srt/lora/layers.py  +1 -1
python/sglang/srt/lora/lora.py  +4 -4
python/sglang/srt/lora/lora_manager.py  +8 -8
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -20,7 +20,7 @@ class AttentionBackend(ABC):
         raise NotImplementedError()
 
     def init_cuda_graph_state(self, max_bs: int):
-        """Init the global shared states for cuda graph."""
+        """Init the global shared states for CUDA graph."""
         raise NotImplementedError()
 
     def init_forward_metadata_capture_cuda_graph(
@@ -33,7 +33,7 @@ class AttentionBackend(ABC):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        """Init the metadata for a forward pass for capturing a cuda graph."""
+        """Init the metadata for a forward pass for capturing a CUDA graph."""
         raise NotImplementedError()
 
     def init_forward_metadata_replay_cuda_graph(
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        """Init the metadata for a forward pass for replaying a cuda graph."""
+        """Init the metadata for a forward pass for replaying a CUDA graph."""
         raise NotImplementedError()
 
     def get_cuda_graph_seq_len_fill_value(self):
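For context, the docstrings touched above belong to the abstract CUDA-graph hooks of AttentionBackend: allocate shared buffers once, fill them before capture, and refresh them before each replay. The sketch below is illustrative only; the simplified signatures and the buffer name cuda_graph_seq_lens are assumptions for the example, not sglang's actual API.

from typing import Optional

import torch


class ToyAttnBackend:
    """Illustrative sketch; real sglang backends take richer metadata."""

    def init_cuda_graph_state(self, max_bs: int):
        # Allocate shared buffers once, sized for the largest batch that
        # will ever be captured into a CUDA graph.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cuda_graph_seq_lens = torch.zeros(max_bs, dtype=torch.int32, device=device)

    def init_forward_metadata_capture_cuda_graph(self, bs: int, seq_lens: torch.Tensor):
        # Fill the pre-allocated buffers before the graph is captured.
        self.cuda_graph_seq_lens[:bs].copy_(seq_lens)

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, seq_lens_cpu: Optional[torch.Tensor]
    ):
        # Refresh only what changes between replays; the captured graph keeps
        # reading from the same buffers it was captured with.
        if seq_lens_cpu is not None:
            self.cuda_graph_seq_lens[:bs].copy_(seq_lens_cpu)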
python/sglang/srt/layers/attention/double_sparsity_backend.py

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 
 class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
-        # Lazy import to avoid the initialization of cuda context
+        # Lazy import to avoid the initialization of CUDA context
        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
            extend_attention_fwd,
            flash_decode_attention_fwd,
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -664,7 +664,7 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr = kv_indptr[: bs + 1]
 
             if wrapper.is_cuda_graph_enabled:
-                # Directly write to the cuda graph input buffer
+                # Directly write to the CUDA graph input buffer
                 kv_indices = wrapper._paged_kv_indices_buf
             else:
                 kv_indices = torch.empty(
@@ -1173,7 +1173,7 @@ def fast_decode_plan(
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
     Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary device-to-device copy for the CUDA graph buffers.
     - Remove unnecessary host-to-device copy for the metadata buffers.
     """
     batch_size = len(last_page_len)
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -874,7 +874,7 @@ def fast_mla_decode_plan(
 ) -> None:
     """A faster version of BatchMLAPagedAttentionWrapper::plan,
-    for skipping the stream synchronization in original plan function during cuda graph replaying.
+    for skipping the stream synchronization in original plan function during CUDA graph replaying.
     """
     self._causal = causal
     self._page_size = page_size
python/sglang/srt/layers/attention/triton_backend.py

@@ -92,7 +92,7 @@ class TritonAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
-        # Lazy import to avoid the initialization of cuda context
+        # Lazy import to avoid the initialization of CUDA context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
python/sglang/srt/layers/attention/vision.py

@@ -257,7 +257,7 @@ class VisionFlash3Attention(nn.Module):
         **kwargs,
     ):
         if not _is_cuda:
-            raise Exception("VisionFlash3Attention is only available for cuda")
+            raise Exception("VisionFlash3Attention is only available for CUDA")
         super().__init__()
 
     def forward(
python/sglang/srt/layers/dp_attention.py

@@ -237,7 +237,7 @@ def dp_scatter(
     forward_batch: ForwardBatch,
 ):
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
-    # since local_tokens may be padded for cuda graph
+    # since local_tokens may be padded for CUDA graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
 
     local_tokens.fill_(0)
python/sglang/srt/layers/logits_processor.py

@@ -166,7 +166,7 @@ class LogitsMetadata:
 
     def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
         if self.global_num_tokens_for_logprob_cpu is None:
-            # we are capturing cuda graph
+            # we are capturing CUDA graph
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
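As context for the hunk above, the cumulative sum over per-rank token counts gives each data-parallel rank the boundaries of its slice in the gathered tensor. A tiny illustration of that bookkeeping follows; the semantics are inferred from the variable names above, not quoted from sglang code.

import torch

# Hypothetical per-rank token counts for three data-parallel ranks.
global_num_tokens = torch.tensor([3, 5, 2])
cumtokens = torch.cumsum(global_num_tokens, dim=0)   # tensor([ 3,  8, 10])

# Rank i's tokens occupy [cumtokens[i] - global_num_tokens[i], cumtokens[i]).
starts = cumtokens - global_num_tokens               # tensor([0, 3, 8])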
python/sglang/srt/layers/quantization/__init__.py

@@ -38,7 +38,7 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-    # Define empty classes as placeholders when vllm is not available
+    # Define empty classes as placeholders when vLLM is not available
     class DummyConfig:
         def override_quantization_method(self, *args, **kwargs):
             return None
@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vllm by `pip install vllm==0.8.4`"
+            "Please install vLLM by `pip install vllm==0.8.4`"
         )
 
     return QUANTIZATION_METHODS[quantization]
@@ -231,7 +231,7 @@ original_isinstance = builtins.isinstance
 def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     """
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
-    can recognize sglang layers
+    can recognize SGLang layers
     """
     if not VLLM_AVAILABLE:
         return
@@ -267,7 +267,7 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     """
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
-    Convert sglang arguments to vllm arguments.
+    Convert SGLang arguments to vLLM arguments.
     """
     original_apply = class_obj.apply
     sig = inspect.signature(original_apply)
@@ -329,6 +329,6 @@ def monkey_patch_quant_configs():
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
 
-# Only apply monkey patches if vllm is available
+# Only apply monkey patches if vLLM is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
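The monkey_patch_moe_apply hunk above patches a third-party apply method and uses inspect.signature to adapt the arguments it receives. Below is a generic, self-contained sketch of that patch-and-adapt idea; the filtering shown here (dropping unknown keyword arguments) is an illustrative stand-in, not sglang's actual SGLang-to-vLLM argument conversion.

import inspect


def patch_apply(class_obj):
    """Wrap class_obj.apply so extra keyword arguments are silently dropped."""
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)

    def new_apply(self, *args, **kwargs):
        # Keep only the keyword arguments the original signature accepts.
        accepted = {k: v for k, v in kwargs.items() if k in sig.parameters}
        return original_apply(self, *args, **accepted)

    class_obj.apply = new_apply

After patch_apply(SomeMethodClass), callers can pass caller-side extras without breaking the original method.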
python/sglang/srt/layers/quantization/blockwise_int8.py

@@ -208,7 +208,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
-        # Use torch Parameter to avoid cuda graph capturing issue
+        # Use torch Parameter to avoid CUDA graph capturing issue
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         layer.weight_scale_inv = torch.nn.Parameter(
             layer.weight_scale_inv.data, requires_grad=False
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -363,7 +363,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
+                    "vLLM is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vLLM"
                 )
             if (
                 self.quant_format == CompressionFormat.marlin_24.value
@@ -409,7 +409,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_fp8_w8a16(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
+                    "vLLM is not installed, to use CompressedTensorsW8A16Fp8, please install vLLM"
                 )
             is_static_input_scheme = input_quant and not input_quant.dynamic
             return CompressedTensorsW8A16Fp8(
@@ -491,7 +491,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         ):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensors24, please install vllm"
+                    "vLLM is not installed, to use CompressedTensors24, please install vLLM"
                 )
             # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -65,7 +65,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
+                    "vLLM is not installed, to use CompressedTensorsWNA16MoEMethod, please install vLLM."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py

@@ -27,10 +27,10 @@ except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
     def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+        raise ImportError("vLLM is not installed")
 
     def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+        raise ImportError("vLLM is not installed")
 
 
 __all__ = ["CompressedTensorsW8A16Fp8"]
@@ -45,7 +45,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
         if not MARLIN_FP8_AVAILABLE:
             raise ImportError(
-                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
+                "vLLM is not installed. To use CompressedTensorsW8A16Fp8, please install vLLM"
             )
 
     @classmethod
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -357,7 +357,7 @@ def apply_fp8_linear(
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vllm cutlass w8a8 fp8 kernel
+                # Fall back to vLLM cutlass w8a8 fp8 kernel
                 output = ops.cutlass_scaled_mm(
                     qinput,
                     weight,
@@ -493,7 +493,7 @@ def apply_fp8_linear(
     if cutlass_fp8_supported:
         try:
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vllm cutlass w8a8 fp8 kernel
+                # Fall back to vLLM cutlass w8a8 fp8 kernel
                 output = ops.cutlass_scaled_mm(
                     qinput,
                     weight,
python/sglang/srt/layers/rotary_embedding.py

@@ -186,8 +186,8 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
     It supports multiple scaling factors. Since multiple LoRA adapters may have
     different scaling factors, we need multiple cos/sin caches. In this way,
-    instead of running rotary embedding kernel per lora, we can run multiple
-    lora in a batched way.
+    instead of running rotary embedding kernel per LoRA adapter, we can run multiple
+    LoRA adapters in a batched way.
     In addition to that, we also keep the cos/sin cache for the scaling factor
     of 1 (default) at all times.
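The docstring above describes keeping one cos/sin cache per LoRA scaling factor so a mixed batch needs only a single gather instead of one kernel launch per adapter. A toy illustration of that layout follows; the tensor names and sizes are invented for the example and are not taken from sglang.

import torch

rot_dim = 4
cache_f1 = torch.randn(8, rot_dim)    # cos/sin cache for scaling factor 1 (default)
cache_f2 = torch.randn(16, rot_dim)   # cache for a second, larger scaling factor
caches = torch.cat([cache_f1, cache_f2], dim=0)   # one stacked cache

# A per-token offset selects the sub-cache; positions stay per-request.
positions = torch.tensor([0, 1, 0, 1])
offsets = torch.tensor([0, 0, 8, 8])   # first two tokens use factor 1, last two use factor 2
cos_sin = caches[positions + offsets]  # single batched lookup for the whole mixed batch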
python/sglang/srt/lora/backend/base_backend.py

@@ -41,13 +41,13 @@ class BaseLoRABackend:
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
-        """Run segment Gemm of lora a modules with current backend.
+        """Run segment Gemm of LoRA a modules with current backend.
 
         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
 
         Args:
            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-           weights: a set of lora weights with shape (num_lora, c * r, input_dim),
-                    here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
+           weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
+                    here r is LoRA rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
                    usually input_dim is much larger than r
         Returns:
            result with shape (s, c * r)
@@ -57,12 +57,12 @@ class BaseLoRABackend:
     def run_lora_b_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
-        """Run segment Gemm of lora b modules with current backend.
+        """Run segment Gemm of LoRA b modules with current backend.
 
         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
 
         Args:
-           x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
-           weights: a set of lora weights with shape (num_lora, output_dim, r)
+           x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is LoRA rank
+           weights: a set of LoRA weights with shape (num_lora, output_dim, r)
                    usually output_dim is much larger than r
         Returns:
            result with shape (s, output_dim)
@@ -77,7 +77,7 @@ class BaseLoRABackend:
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        """Run the lora pass for QKV Layer.
+        """Run the LoRA pass for QKV Layer.
 
         Args:
            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
@@ -100,7 +100,7 @@ class BaseLoRABackend:
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+        """Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
         Args:
            x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
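The docstrings above pin down the shape contract of the segment GEMM: x is (s, input_dim) with s the sum of all sequence lengths, and weights is (num_lora, c * r, input_dim). A naive, loop-based reference that matches those shapes is sketched below; the seg_indptr and weight_indices bookkeeping tensors are assumptions made for the illustration, and sglang's real backends dispatch to fused Triton/FlashInfer kernels instead.

import torch


def naive_lora_a_sgemm(
    x: torch.Tensor,               # (s, input_dim)
    weights: torch.Tensor,         # (num_lora, c * r, input_dim)
    seg_indptr: torch.Tensor,      # (num_segments + 1,) row boundaries of each segment in x
    weight_indices: torch.Tensor,  # (num_segments,) adapter index used by each segment
) -> torch.Tensor:
    out = x.new_empty(x.shape[0], weights.shape[1])  # (s, c * r)
    for i in range(weight_indices.numel()):
        lo, hi = int(seg_indptr[i]), int(seg_indptr[i + 1])
        # Each segment multiplies against its own adapter's lora_a weight.
        out[lo:hi] = x[lo:hi] @ weights[weight_indices[i]].T
    return out


# Tiny smoke test: two sequences (lengths 2 and 3) using adapters 0 and 1.
x = torch.randn(5, 16)
weights = torch.randn(2, 8, 16)          # num_lora=2, c * r = 8, input_dim=16
seg_indptr = torch.tensor([0, 2, 5])
weight_indices = torch.tensor([0, 1])
assert naive_lora_a_sgemm(x, weights, seg_indptr, weight_indices).shape == (5, 8)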
python/sglang/srt/lora/backend/flashinfer_backend.py

@@ -117,7 +117,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
             dtype=x.dtype,
         )
 
-        # Compute lora for gate and up proj respectively
+        # Compute LoRA for gate and up proj respectively
         lora_output[:, :output_dim] = self.run_lora_b_sgemm(
             x=lora_a_output[:, :lora_rank].contiguous(),
             weights=gate_up_lora_b[0],
python/sglang/srt/lora/layers.py

@@ -198,7 +198,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         if self.lora_backend.fuse_stacked_lora_b:
             assert (
                 B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
+            ), "The LoRA rank of q and kv should be the same when enabling fusion of qkv lora_b"
             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
 
             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
python/sglang/srt/lora/lora.py

@@ -40,7 +40,7 @@ class LoRALayer(nn.Module):
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config
 
-        # lora weights in cpu. The weights are loaded from checkpoint.
+        # LoRA weights in cpu. The weights are loaded from checkpoint.
         self.weights: Dict[str, torch.Tensor] = {}
@@ -97,7 +97,7 @@ class LoRAAdapter(nn.Module):
     def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
-        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
+        # Collect target q/k/v modules. This process is necessary since there might be no LoRA attached to k_proj
         target_module = set()
         for weight_name in weight_names:
             if "k_proj" in weight_name:
@@ -110,7 +110,7 @@ class LoRAAdapter(nn.Module):
             return
 
         for weight_name in weight_names:
-            # We assume every lora adaptor should contain lora modules for q_proj
+            # We assume every LoRA adaptor should contain LoRA modules for q_proj
             if "q_proj" in weight_name:
                 q_name = weight_name
                 k_name = weight_name.replace("q_proj", "k_proj")
@@ -118,7 +118,7 @@ class LoRAAdapter(nn.Module):
                 kv_name = weight_name.replace("q_proj", "kv_proj")
                 qkv_name = weight_name.replace("q_proj", "qkv_proj")
 
-                # If k_proj doesn't have lora, initialize it to zero
+                # If k_proj doesn't have LoRA, initialize it to zero
                 k_proj_weight = (
                     weights[k_name]
                     if "k_proj" in target_module
python/sglang/srt/lora/lora_manager.py

@@ -93,14 +93,14 @@ class LoRAManager:
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
 
-        # Target module names in hugging face lora configs.
+        # Target module names in Hugging Face LoRA configs.
         # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
         self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
             self.hf_target_names.update(self.configs[name].target_modules)
 
-        # Target lora weight names for lora_a and lora_b modules respectively.
+        # Target LoRA weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]
@@ -119,11 +119,11 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter
 
-        # misc lora configs
+        # misc LoRA configs
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
 
         if self.lora_backend == "flashinfer":
-            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+            # FIXME: remove the restrictions after supporting multi-rank for flashinfer backend
             max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
             scaling = list(self.loras.values())[0].scaling
             assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
@@ -144,16 +144,16 @@ class LoRAManager:
             self.lora_modules,
         )
 
-        # Initialize target lora modules in memory pool
+        # Initialize target LoRA modules in memory pool
         self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-        # load active loras into lora memory pool
+        # load active LoRAs into LoRA memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        # set up batch info shared by all lora modules
+        # set up batch info shared by all LoRA modules
         bs = forward_batch.batch_size
 
         if (
@@ -221,7 +221,7 @@ class LoRAManager:
         )
         self.lora_backend.set_batch_info(batch_info)
 
-        # call set_lora_info for each lora modules
+        # call set_lora_info for each LoRA modules
         for layer_id, modules in self.lora_modules.items():
             for module_name, module in modules:
                 if "qkv_proj" in module_name: