change / sglang · Commits · e8e18dcd

Commit e8e18dcd (unverified), authored May 12, 2025 by Lianmin Zheng, committed by GitHub on May 12, 2025.

Revert "fix some typos" (#6244)

parent bad7c26f

Changes: 95 files in this commit; this page shows 20 changed files with 50 additions and 50 deletions (+50, -50).
python/sglang/srt/layers/attention/base_attn_backend.py   +3 -3
python/sglang/srt/layers/attention/double_sparsity_backend.py   +1 -1
python/sglang/srt/layers/attention/flashinfer_backend.py   +2 -2
python/sglang/srt/layers/attention/flashinfer_mla_backend.py   +1 -1
python/sglang/srt/layers/attention/triton_backend.py   +1 -1
python/sglang/srt/layers/attention/vision.py   +1 -1
python/sglang/srt/layers/dp_attention.py   +1 -1
python/sglang/srt/layers/logits_processor.py   +1 -1
python/sglang/srt/layers/quantization/__init__.py   +5 -5
python/sglang/srt/layers/quantization/blockwise_int8.py   +1 -1
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py   +3 -3
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py   +1 -1
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py   +3 -3
python/sglang/srt/layers/quantization/fp8_utils.py   +2 -2
python/sglang/srt/layers/rotary_embedding.py   +2 -2
python/sglang/srt/lora/backend/base_backend.py   +8 -8
python/sglang/srt/lora/backend/flashinfer_backend.py   +1 -1
python/sglang/srt/lora/layers.py   +1 -1
python/sglang/srt/lora/lora.py   +4 -4
python/sglang/srt/lora/lora_manager.py   +8 -8

python/sglang/srt/layers/attention/base_attn_backend.py

@@ -20,7 +20,7 @@ class AttentionBackend(ABC):
         raise NotImplementedError()
 
     def init_cuda_graph_state(self, max_bs: int):
-        """Init the global shared states for CUDA graph."""
+        """Init the global shared states for cuda graph."""
         raise NotImplementedError()
 
     def init_forward_metadata_capture_cuda_graph(
@@ -33,7 +33,7 @@ class AttentionBackend(ABC):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        """Init the metadata for a forward pass for capturing a CUDA graph."""
+        """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
 
     def init_forward_metadata_replay_cuda_graph(
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        """Init the metadata for a forward pass for replaying a CUDA graph."""
+        """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()
 
     def get_cuda_graph_seq_len_fill_value(self):

python/sglang/srt/layers/attention/double_sparsity_backend.py

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 
 class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
-        # Lazy import to avoid the initialization of CUDA context
+        # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
             extend_attention_fwd,
             flash_decode_attention_fwd,

python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -664,7 +664,7 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr = kv_indptr[: bs + 1]
 
             if wrapper.is_cuda_graph_enabled:
-                # Directly write to the CUDA graph input buffer
+                # Directly write to the cuda graph input buffer
                 kv_indices = wrapper._paged_kv_indices_buf
             else:
                 kv_indices = torch.empty(
@@ -1173,7 +1173,7 @@ def fast_decode_plan(
     """
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
    Modifications:
-    - Remove unnecessary device-to-device copy for the CUDA graph buffers.
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
    - Remove unnecessary host-to-device copy for the metadata buffers.
     """
     batch_size = len(last_page_len)

python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -874,7 +874,7 @@ def fast_mla_decode_plan(
 ) -> None:
     """A faster version of BatchMLAPagedAttentionWrapper::plan,
     for skipping the stream synchronization in original plan function during
-    CUDA graph replaying.
+    cuda graph replaying.
     """
     self._causal = causal
     self._page_size = page_size

python/sglang/srt/layers/attention/triton_backend.py

@@ -92,7 +92,7 @@ class TritonAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
-        # Lazy import to avoid the initialization of CUDA context
+        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )

python/sglang/srt/layers/attention/vision.py

@@ -257,7 +257,7 @@ class VisionFlash3Attention(nn.Module):
         **kwargs,
     ):
         if not _is_cuda:
-            raise Exception("VisionFlash3Attention is only available for CUDA")
+            raise Exception("VisionFlash3Attention is only available for cuda")
         super().__init__()
 
     def forward(

python/sglang/srt/layers/dp_attention.py

@@ -237,7 +237,7 @@ def dp_scatter(
     forward_batch: ForwardBatch,
 ):
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
-    # since local_tokens may be padded for CUDA graph
+    # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
 
     local_tokens.fill_(0)
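
The comment touched by this hunk notes that local_tokens may be longer than local_num_tokens because the buffer is padded for CUDA graph capture. Below is a minimal, self-contained sketch of that scatter-with-padding idea; the tensor names, shapes, and copy direction are illustrative assumptions, not sglang's actual dp_scatter code.

import torch

# Toy stand-ins: 6 "global" token rows gathered across DP ranks; this rank owns rows 2..4.
global_tokens = torch.arange(12.0).reshape(6, 2)
local_start_pos, local_num_tokens = 2, 3

# The local buffer is sized for the captured cuda graph batch, so it may be padded.
padded_len = 4
local_tokens = torch.zeros(padded_len, 2)

# Only the first local_num_tokens rows carry real data; padding rows stay zero.
local_tokens.fill_(0)
local_tokens[:local_num_tokens] = global_tokens[
    local_start_pos : local_start_pos + local_num_tokens
]
print(local_tokens)  # rows 0..2 hold this rank's tokens, row 3 is padding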

python/sglang/srt/layers/logits_processor.py

@@ -166,7 +166,7 @@ class LogitsMetadata:
 
     def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
         if self.global_num_tokens_for_logprob_cpu is None:
-            # we are capturing CUDA graph
+            # we are capturing cuda graph
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)

python/sglang/srt/layers/quantization/__init__.py

@@ -38,7 +38,7 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-    # Define empty classes as placeholders when vLLM is not available
+    # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
         def override_quantization_method(self, *args, **kwargs):
             return None
@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vLLM by `pip install vllm==0.8.4`"
+            "Please install vllm by `pip install vllm==0.8.4`"
         )
 
     return QUANTIZATION_METHODS[quantization]
@@ -231,7 +231,7 @@ original_isinstance = builtins.isinstance
 def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     """
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
-    can recognize SGLang layers
+    can recognize sglang layers
     """
     if not VLLM_AVAILABLE:
         return
@@ -267,7 +267,7 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     """
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
-    Convert SGLang arguments to vLLM arguments.
+    Convert sglang arguments to vllm arguments.
     """
     original_apply = class_obj.apply
     sig = inspect.signature(original_apply)
@@ -329,6 +329,6 @@ def monkey_patch_quant_configs():
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
 
-# Only apply monkey patches if vLLM is available
+# Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
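
The monkey_patch_isinstance_for_vllm_base_layer docstring above describes swapping builtins.isinstance so that vllm's type checks also accept sglang layer classes. The snippet below is a self-contained toy version of that pattern only; VllmLinearBase and SglLinear are hypothetical stand-ins, not the real vllm/sglang types.

import builtins

original_isinstance = builtins.isinstance

class VllmLinearBase: ...   # stands in for a vllm base layer class
class SglLinear: ...        # stands in for an sglang layer class

def monkey_patch_isinstance(reverse: bool = False):
    # Restore the untouched builtin when reversing the patch.
    if reverse:
        builtins.isinstance = original_isinstance
        return

    def patched_isinstance(obj, classinfo):
        # Treat SglLinear instances as VllmLinearBase while the patch is active.
        if classinfo is VllmLinearBase and original_isinstance(obj, SglLinear):
            return True
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance

monkey_patch_isinstance()
print(isinstance(SglLinear(), VllmLinearBase))  # True while patched
monkey_patch_isinstance(reverse=True)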

python/sglang/srt/layers/quantization/blockwise_int8.py

@@ -208,7 +208,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
 
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
-        # Use torch Parameter to avoid CUDA graph capturing issue
+        # Use torch Parameter to avoid cuda graph capturing issue
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         layer.weight_scale_inv = torch.nn.Parameter(
             layer.weight_scale_inv.data, requires_grad=False

python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -363,7 +363,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vLLM"
+                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
                 )
             if (
                 self.quant_format == CompressionFormat.marlin_24.value
@@ -409,7 +409,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_fp8_w8a16(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsW8A16Fp8, please install vLLM"
+                    "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
                 )
             is_static_input_scheme = input_quant and not input_quant.dynamic
             return CompressedTensorsW8A16Fp8(
@@ -491,7 +491,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         ):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensors24, please install vLLM"
+                    "vllm is not installed, to use CompressedTensors24, please install vllm"
                 )
             # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel

python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -65,7 +65,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsWNA16MoEMethod, please install vLLM."
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):

python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py

@@ -27,10 +27,10 @@ except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
     def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vLLM is not installed")
+        raise ImportError("vllm is not installed")
 
     def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vLLM is not installed")
+        raise ImportError("vllm is not installed")
 
 
 __all__ = ["CompressedTensorsW8A16Fp8"]
@@ -45,7 +45,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
 
         if not MARLIN_FP8_AVAILABLE:
             raise ImportError(
-                "vLLM is not installed. To use CompressedTensorsW8A16Fp8, please install vLLM"
+                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
             )
 
     @classmethod

python/sglang/srt/layers/quantization/fp8_utils.py

@@ -357,7 +357,7 @@ def apply_fp8_linear(
 
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vLLM cutlass w8a8 fp8 kernel
+                # Fall back to vllm cutlass w8a8 fp8 kernel
                 output = ops.cutlass_scaled_mm(
                     qinput,
                     weight,
@@ -493,7 +493,7 @@ def apply_fp8_linear(
         if cutlass_fp8_supported:
             try:
                 if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                    # Fall back to vLLM cutlass w8a8 fp8 kernel
+                    # Fall back to vllm cutlass w8a8 fp8 kernel
                     output = ops.cutlass_scaled_mm(
                         qinput,
                         weight,

python/sglang/srt/layers/rotary_embedding.py

@@ -186,8 +186,8 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
 
     It supports multiple scaling factors. Since multiple LoRA adapters may have
     different scaling factors, we need multiple cos/sin caches. In this way,
-    instead of running rotary embedding kernel per LoRA adapter, we can run multiple
-    LoRA adapters in a batched way.
+    instead of running rotary embedding kernel per lora, we can run multiple
+    lora in a batched way.
 
     In addition to that, we also keep the cos/sin cache for the scaling factor
     of 1 (default) at all times.
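
The docstring in this hunk describes keeping one cos/sin cache per scaling factor so requests from adapters with different factors can share a single batched rotary kernel. A rough sketch of that caching scheme follows, under assumed names and a simple linear-scaling layout; it is an illustration, not the actual sglang implementation.

import torch

def build_caches(head_dim, max_pos, scaling_factors, base=10000.0):
    # Standard RoPE inverse frequencies for half of head_dim.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    caches, offsets, offset = [], {}, 0
    for factor in scaling_factors:
        # Linear scaling: positions are divided by the scaling factor.
        t = torch.arange(int(max_pos * factor)).float() / factor
        freqs = torch.outer(t, inv_freq)
        caches.append(torch.cat([freqs.cos(), freqs.sin()], dim=-1))
        offsets[factor] = offset
        offset += caches[-1].shape[0]
    # One concatenated cache plus an offset table per scaling factor.
    return torch.cat(caches, dim=0), offsets

cache, offsets = build_caches(head_dim=64, max_pos=2048, scaling_factors=[1.0, 2.0])
# A position p under scaling factor f indexes row offsets[f] + p of the shared cache,
# so a batch mixing both factors needs only one gather.
print(cache.shape, offsets)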

python/sglang/srt/lora/backend/base_backend.py

@@ -41,13 +41,13 @@ class BaseLoRABackend:
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
-        """Run segment Gemm of LoRA a modules with current backend.
+        """Run segment Gemm of lora a modules with current backend.
         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
 
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-            weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
-                here r is LoRA rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
+            weights: a set of lora weights with shape (num_lora, c * r, input_dim),
+                here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
             usually input_dim is much larger than r
         Returns:
             result with shape (s, c * r)
@@ -57,12 +57,12 @@ class BaseLoRABackend:
     def run_lora_b_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
-        """Run segment Gemm of LoRA b modules with current backend.
+        """Run segment Gemm of lora b modules with current backend.
         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
 
         Args:
-            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is LoRA rank
-            weights: a set of LoRA weights with shape (num_lora, output_dim, r)
+            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+            weights: a set of lora weights with shape (num_lora, output_dim, r)
             usually output_dim is much larger than r
         Returns:
             result with shape (s, output_dim)
@@ -77,7 +77,7 @@ class BaseLoRABackend:
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        """Run the LoRA pass for QKV Layer.
+        """Run the lora pass for QKV Layer.
 
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
@@ -100,7 +100,7 @@ class BaseLoRABackend:
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        """Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
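
The docstrings in these hunks define the segment GEMM shapes for the lora_a pass: x is (s, input_dim), weights are (num_lora, c * r, input_dim), and the result is (s, c * r). Below is a loop-based reference implementation of just that shape contract; the (seg_lens, weight_indices) batching scheme is an assumption for illustration, since real backends fuse this into a single kernel.

import torch

def ref_lora_a_sgemm(x, weights, seg_lens, weight_indices):
    # x: (s, input_dim) with s = sum(seg_lens); weights: (num_lora, c * r, input_dim).
    outputs = []
    start = 0
    for seg_len, idx in zip(seg_lens, weight_indices):
        seg = x[start : start + seg_len]        # (seg_len, input_dim)
        outputs.append(seg @ weights[idx].T)    # (seg_len, c * r), adapter idx
        start += seg_len
    return torch.cat(outputs, dim=0)            # (s, c * r)

s, input_dim, num_lora, c, r = 7, 64, 2, 3, 8
x = torch.randn(s, input_dim)
weights = torch.randn(num_lora, c * r, input_dim)
out = ref_lora_a_sgemm(x, weights, seg_lens=[3, 4], weight_indices=[0, 1])
print(out.shape)  # torch.Size([7, 24]) == (s, c * r)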

python/sglang/srt/lora/backend/flashinfer_backend.py

@@ -117,7 +117,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
             dtype=x.dtype,
         )
 
-        # Compute LoRA for gate and up proj respectively
+        # Compute lora for gate and up proj respectively
         lora_output[:, :output_dim] = self.run_lora_b_sgemm(
             x=lora_a_output[:, :lora_rank].contiguous(),
             weights=gate_up_lora_b[0],

python/sglang/srt/lora/layers.py

@@ -198,7 +198,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         if self.lora_backend.fuse_stacked_lora_b:
             assert (
                 B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The LoRA rank of q and kv should be the same when enabling fusion of qkv lora_b"
+            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
 
             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)

python/sglang/srt/lora/lora.py

@@ -40,7 +40,7 @@ class LoRALayer(nn.Module):
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config
 
-        # LoRA weights in cpu. The weights are loaded from checkpoint.
+        # lora weights in cpu. The weights are loaded from checkpoint.
         self.weights: Dict[str, torch.Tensor] = {}
@@ -97,7 +97,7 @@ class LoRAAdapter(nn.Module):
     def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
-        # Collect target q/k/v modules. This process is necessary since there might be no LoRA attached to k_proj
+        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
         target_module = set()
         for weight_name in weight_names:
             if "k_proj" in weight_name:
@@ -110,7 +110,7 @@ class LoRAAdapter(nn.Module):
                 return
 
         for weight_name in weight_names:
-            # We assume every LoRA adaptor should contain LoRA modules for q_proj
+            # We assume every lora adaptor should contain lora modules for q_proj
            if "q_proj" in weight_name:
                q_name = weight_name
                k_name = weight_name.replace("q_proj", "k_proj")
@@ -118,7 +118,7 @@ class LoRAAdapter(nn.Module):
                kv_name = weight_name.replace("q_proj", "kv_proj")
                qkv_name = weight_name.replace("q_proj", "qkv_proj")
-                # If k_proj doesn't have LoRA, initialize it to zero
+                # If k_proj doesn't have lora, initialize it to zero
                k_proj_weight = (
                    weights[k_name]
                    if "k_proj" in target_module

python/sglang/srt/lora/lora_manager.py

@@ -93,14 +93,14 @@ class LoRAManager:
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
-        # Target module names in HuggingFace LoRA configs.
+        # Target module names in huggingface lora configs.
         # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
         self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
             self.hf_target_names.update(self.configs[name].target_modules)
 
-        # Target LoRA weight names for lora_a and lora_b modules respectively.
+        # Target lora weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]
@@ -119,11 +119,11 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter
 
-        # misc LoRA configs
+        # misc lora configs
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
 
         if self.lora_backend == "flashinfer":
-            # FIXME: remove the restrictions after supporting multi-rank for flashinfer backend
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
             max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
             scaling = list(self.loras.values())[0].scaling
             assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
@@ -144,16 +144,16 @@ class LoRAManager:
             self.lora_modules,
         )
 
-        # Initialize target LoRA modules in memory pool
+        # Initialize target lora modules in memory pool
         self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-        # load active LoRAs into LoRA memory pool
+        # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        # set up batch info shared by all LoRA modules
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
         if (
@@ -221,7 +221,7 @@ class LoRAManager:
         )
         self.lora_backend.set_batch_info(batch_info)
 
-        # call set_lora_info for each LoRA modules
+        # call set_lora_info for each lora modules
         for layer_id, modules in self.lora_modules.items():
             for module_name, module in modules:
                 if "qkv_proj" in module_name: