Unverified commit e8e18dcd, authored by Lianmin Zheng, committed by GitHub

Revert "fix some typos" (#6244)

parent bad7c26f
@@ -20,7 +20,7 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_cuda_graph_state(self, max_bs: int):
-        """Init the global shared states for CUDA graph."""
+        """Init the global shared states for cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
@@ -33,7 +33,7 @@ class AttentionBackend(ABC):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        """Init the metadata for a forward pass for capturing a CUDA graph."""
+        """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        """Init the metadata for a forward pass for replaying a CUDA graph."""
+        """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()

     def get_cuda_graph_seq_len_fill_value(self):
...
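For context, these hooks exist because a captured CUDA graph must always read and write the same memory; below is a minimal sketch of the capture/replay pattern in plain PyTorch (hypothetical buffers and a toy matmul, not SGLang's runner code):

```python
import torch

# Static buffers: a captured graph always reuses the same memory, which is why
# the backend pre-initializes shared state and metadata buffers up front.
static_x = torch.zeros(8, 16, device="cuda")
static_w = torch.randn(16, 16, device="cuda")
static_out = torch.zeros(8, 16, device="cuda")

# Warm-up outside the graph so lazy kernel/library init does not happen during capture.
static_out.copy_(static_x @ static_w)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):                         # "capture" phase
    static_out.copy_(static_x @ static_w)

static_x.copy_(torch.randn(8, 16, device="cuda"))     # refill the input buffer in place
graph.replay()                                        # "replay" reuses the recorded kernels
```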
@@ -15,7 +15,7 @@ if TYPE_CHECKING:

 class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
-        # Lazy import to avoid the initialization of CUDA context
+        # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
             extend_attention_fwd,
             flash_decode_attention_fwd,
...
@@ -664,7 +664,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr = kv_indptr[: bs + 1]

         if wrapper.is_cuda_graph_enabled:
-            # Directly write to the CUDA graph input buffer
+            # Directly write to the cuda graph input buffer
             kv_indices = wrapper._paged_kv_indices_buf
         else:
             kv_indices = torch.empty(
@@ -1173,7 +1173,7 @@ def fast_decode_plan(
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
     Modifications:
-    - Remove unnecessary device-to-device copy for the CUDA graph buffers.
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
     - Remove unnecessary host-to-device copy for the metadata buffers.
     """
     batch_size = len(last_page_len)
...
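The `fast_decode_plan` hunk above is about reusing buffers that the CUDA graph already owns. A hedged illustration of that idea with a hypothetical buffer and helper name (not the flashinfer API):

```python
import torch

# A persistent index buffer owned by the captured graph (size is arbitrary here).
paged_kv_indices_buf = torch.zeros(1 << 16, dtype=torch.int32, device="cuda")

def write_kv_indices(new_indices: torch.Tensor) -> None:
    # Write directly into the graph's buffer: one in-place copy, no fresh
    # allocation and no extra device-to-device round trip.
    paged_kv_indices_buf[: new_indices.numel()].copy_(new_indices)
```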
@@ -874,7 +874,7 @@ def fast_mla_decode_plan(
 ) -> None:
     """A faster version of BatchMLAPagedAttentionWrapper::plan,
     for skipping the stream synchronization in original plan function during
-    CUDA graph replaying.
+    cuda graph replaying.
     """
     self._causal = causal
     self._page_size = page_size
...
@@ -92,7 +92,7 @@ class TritonAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
     ):
-        # Lazy import to avoid the initialization of CUDA context
+        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
             decode_attention_fwd,
         )
...
@@ -257,7 +257,7 @@ class VisionFlash3Attention(nn.Module):
         **kwargs,
     ):
         if not _is_cuda:
-            raise Exception("VisionFlash3Attention is only available for CUDA")
+            raise Exception("VisionFlash3Attention is only available for cuda")
         super().__init__()

     def forward(
...
@@ -237,7 +237,7 @@ def dp_scatter(
     forward_batch: ForwardBatch,
 ):
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
-    # since local_tokens may be padded for CUDA graph
+    # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

     local_tokens.fill_(0)
...
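Since `local_tokens` may be padded for CUDA graphs, only the first `local_num_tokens` rows are meaningful. A rough reference of the scatter, mirroring the names in the hunk (assumed shapes, not the actual `dp_scatter` body):

```python
import torch

def dp_scatter_ref(
    local_tokens: torch.Tensor,    # (padded_local_len, hidden), padded for cuda graph
    global_tokens: torch.Tensor,   # (global_len, hidden)
    local_start_pos: int,
    local_num_tokens: int,
) -> None:
    # Zero the (possibly padded) local buffer, then copy this rank's slice in.
    local_tokens.fill_(0)
    local_tokens[:local_num_tokens].copy_(
        global_tokens[local_start_pos : local_start_pos + local_num_tokens]
    )
```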
@@ -166,7 +166,7 @@ class LogitsMetadata:

     def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
         if self.global_num_tokens_for_logprob_cpu is None:
-            # we are capturing CUDA graph
+            # we are capturing cuda graph
             return

         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
...
@@ -38,7 +38,7 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False

-    # Define empty classes as placeholders when vLLM is not available
+    # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
         def override_quantization_method(self, *args, **kwargs):
             return None
@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vLLM by `pip install vllm==0.8.4`"
+            "Please install vllm by `pip install vllm==0.8.4`"
         )

     return QUANTIZATION_METHODS[quantization]
@@ -231,7 +231,7 @@ original_isinstance = builtins.isinstance
 def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     """
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
-    can recognize SGLang layers
+    can recognize sglang layers
     """
     if not VLLM_AVAILABLE:
         return
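For readers unfamiliar with the trick in `monkey_patch_isinstance_for_vllm_base_layer`, here is a hedged, generic sketch of patching `builtins.isinstance` so that instances of one class are also accepted where another class is expected (the `equivalents` mapping and helper name are hypothetical, not the exact SGLang implementation):

```python
import builtins

_original_isinstance = builtins.isinstance

def patch_isinstance(equivalents: dict, reverse: bool = False) -> None:
    """equivalents maps an expected class to another class whose
    instances should also pass the isinstance check."""
    if reverse:
        builtins.isinstance = _original_isinstance
        return

    def patched_isinstance(obj, classinfo):
        if _original_isinstance(obj, classinfo):
            return True
        # Only handle plain single-class checks; tuple checks were handled above.
        if _original_isinstance(classinfo, type) and classinfo in equivalents:
            return _original_isinstance(obj, equivalents[classinfo])
        return False

    builtins.isinstance = patched_isinstance
```

With `patch_isinstance({SomeVllmBase: SomeSglangLayer})` active (placeholder class names), library code that checks `isinstance(layer, SomeVllmBase)` also accepts the SGLang layer; `reverse=True` restores the builtin.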
@@ -267,7 +267,7 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     """
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
-    Convert SGLang arguments to vLLM arguments.
+    Convert sglang arguments to vllm arguments.
     """
     original_apply = class_obj.apply
     sig = inspect.signature(original_apply)
@@ -329,6 +329,6 @@ def monkey_patch_quant_configs():
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


-# Only apply monkey patches if vLLM is available
+# Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
@@ -208,7 +208,7 @@ class BlockInt8LinearMethod(LinearMethodBase):

     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
-        # Use torch Parameter to avoid CUDA graph capturing issue
+        # Use torch Parameter to avoid cuda graph capturing issue
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         layer.weight_scale_inv = torch.nn.Parameter(
             layer.weight_scale_inv.data, requires_grad=False
...
@@ -363,7 +363,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vLLM"
+                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
                 )
             if (
                 self.quant_format == CompressionFormat.marlin_24.value
@@ -409,7 +409,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_fp8_w8a16(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsW8A16Fp8, please install vLLM"
+                    "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
                 )
             is_static_input_scheme = input_quant and not input_quant.dynamic
             return CompressedTensorsW8A16Fp8(
@@ -491,7 +491,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         ):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensors24, please install vLLM"
+                    "vllm is not installed, to use CompressedTensors24, please install vllm"
                 )
             # Have a valid sparsity scheme
             # Validate layer is supported by Cutlass 2:4 Kernel
...
@@ -65,7 +65,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vLLM is not installed, to use CompressedTensorsWNA16MoEMethod, please install vLLM."
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
...
@@ -27,10 +27,10 @@ except ImportError:
     MARLIN_FP8_AVAILABLE = False

     def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vLLM is not installed")
+        raise ImportError("vllm is not installed")

     def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vLLM is not installed")
+        raise ImportError("vllm is not installed")


 __all__ = ["CompressedTensorsW8A16Fp8"]
@@ -45,7 +45,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
         if not MARLIN_FP8_AVAILABLE:
             raise ImportError(
-                "vLLM is not installed. To use CompressedTensorsW8A16Fp8, please install vLLM"
+                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
             )

     @classmethod
...
@@ -357,7 +357,7 @@ def apply_fp8_linear(

         # Fused GEMM_DQ
         if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-            # Fall back to vLLM cutlass w8a8 fp8 kernel
+            # Fall back to vllm cutlass w8a8 fp8 kernel
             output = ops.cutlass_scaled_mm(
                 qinput,
                 weight,
@@ -493,7 +493,7 @@ def apply_fp8_linear(
     if cutlass_fp8_supported:
         try:
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-                # Fall back to vLLM cutlass w8a8 fp8 kernel
+                # Fall back to vllm cutlass w8a8 fp8 kernel
                 output = ops.cutlass_scaled_mm(
                     qinput,
                     weight,
...
@@ -186,8 +186,8 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):

     It supports multiple scaling factors. Since multiple LoRA adapters may have
     different scaling factors, we need multiple cos/sin caches. In this way,
-    instead of running rotary embedding kernel per LoRA adapter, we can run multiple
-    LoRA adapters in a batched way.
+    instead of running rotary embedding kernel per lora, we can run multiple
+    lora in a batched way.

     In addition to that, we also keep the cos/sin cache for the scaling factor
     of 1 (default) at all times.
...
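The docstring explains why several cos/sin caches are kept: requests whose LoRA adapters use different scaling factors can then share one batched rotary-embedding call. A hedged sketch of that cache layout, assuming linear position scaling and an offset table per factor (illustrative only, not the class internals):

```python
import torch

def build_multi_scale_cos_sin_cache(
    rotary_dim: int,
    max_position: int,
    scaling_factors: list,
    base: float = 10000.0,
):
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    caches, offsets, offset = [], {}, 0
    # Always include the default factor of 1, as the docstring requires.
    for factor in sorted(set(scaling_factors) | {1.0}):
        num_pos = int(max_position * factor)
        t = torch.arange(num_pos).float() / factor       # linear position scaling
        freqs = torch.outer(t, inv_freq)
        caches.append(torch.cat((freqs.cos(), freqs.sin()), dim=-1))
        offsets[factor] = offset                         # where this factor's cache starts
        offset += num_pos
    # One long cache; a batch indexes it with (offsets[factor] + position).
    return torch.cat(caches, dim=0), offsets
```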
@@ -41,13 +41,13 @@ class BaseLoRABackend:
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
-        """Run segment Gemm of LoRA a modules with current backend.
+        """Run segment Gemm of lora a modules with current backend.
         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.

         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
-            weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
-               here r is LoRA rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
+            weights: a set of lora weights with shape (num_lora, c * r, input_dim),
+               here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
                usually input_dim is much larger than r

         Returns:
             result with shape (s, c * r)
@@ -57,12 +57,12 @@ class BaseLoRABackend:
     def run_lora_b_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
     ) -> torch.Tensor:
-        """Run segment Gemm of LoRA b modules with current backend.
+        """Run segment Gemm of lora b modules with current backend.
         The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.

         Args:
-            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is LoRA rank
-            weights: a set of LoRA weights with shape (num_lora, output_dim, r)
+            x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+            weights: a set of lora weights with shape (num_lora, output_dim, r)
                usually output_dim is much larger than r

         Returns:
             result with shape (s, output_dim)
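The shapes in these docstrings are easiest to see with a naive reference loop. Below is a plain-PyTorch stand-in for the segment GEMM contract; the `seg_lens` and `weight_indices` arguments are assumptions, not necessarily the backend's real signature:

```python
import torch

def segment_gemm_ref(x, weights, seg_lens, weight_indices):
    # x: (s, input_dim); weights: (num_lora, c * r, input_dim)
    # seg_lens[i]: number of rows of x belonging to request i
    # weight_indices[i]: which lora adapter request i uses
    outputs, start = [], 0
    for seg_len, w_idx in zip(seg_lens, weight_indices):
        seg = x[start : start + seg_len]            # (seg_len, input_dim)
        outputs.append(seg @ weights[w_idx].T)      # (seg_len, c * r)
        start += seg_len
    return torch.cat(outputs, dim=0)                # (s, c * r)
```

For a non-stacked module (c = 1), feeding the (s, r) output of the lora_a GEMM into the lora_b GEMM yields the (s, output_dim) LoRA delta.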
@@ -77,7 +77,7 @@ class BaseLoRABackend:
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        """Run the LoRA pass for QKV Layer.
+        """Run the lora pass for QKV Layer.

         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
@@ -100,7 +100,7 @@ class BaseLoRABackend:
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        """Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
+        """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.

         Args:
             x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
...
@@ -117,7 +117,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
             dtype=x.dtype,
         )

-        # Compute LoRA for gate and up proj respectively
+        # Compute lora for gate and up proj respectively
         lora_output[:, :output_dim] = self.run_lora_b_sgemm(
             x=lora_a_output[:, :lora_rank].contiguous(),
             weights=gate_up_lora_b[0],
...
@@ -198,7 +198,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         if self.lora_backend.fuse_stacked_lora_b:
             assert (
                 B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The LoRA rank of q and kv should be the same when enabling fusion of qkv lora_b"
+            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
...
@@ -40,7 +40,7 @@ class LoRALayer(nn.Module):
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config

-        # LoRA weights in cpu. The weights are loaded from checkpoint.
+        # lora weights in cpu. The weights are loaded from checkpoint.
         self.weights: Dict[str, torch.Tensor] = {}
@@ -97,7 +97,7 @@ class LoRAAdapter(nn.Module):

     def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
-        # Collect target q/k/v modules. This process is necessary since there might be no LoRA attached to k_proj
+        # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
         target_module = set()
         for weight_name in weight_names:
             if "k_proj" in weight_name:
@@ -110,7 +110,7 @@ class LoRAAdapter(nn.Module):
             return

         for weight_name in weight_names:
-            # We assume every LoRA adaptor should contain LoRA modules for q_proj
+            # We assume every lora adaptor should contain lora modules for q_proj
             if "q_proj" in weight_name:
                 q_name = weight_name
                 k_name = weight_name.replace("q_proj", "k_proj")
@@ -118,7 +118,7 @@ class LoRAAdapter(nn.Module):
                 kv_name = weight_name.replace("q_proj", "kv_proj")
                 qkv_name = weight_name.replace("q_proj", "qkv_proj")

-                # If k_proj doesn't have LoRA, initialize it to zero
+                # If k_proj doesn't have lora, initialize it to zero
                 k_proj_weight = (
                     weights[k_name]
                     if "k_proj" in target_module
...
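To make the zero-initialization above concrete, here is a hedged sketch of stacking q/k/v lora_b weights into a single qkv entry, with zeros standing in for a projection that has no LoRA weights (the helper name and the assumption that k_proj and v_proj share an output shape are illustrative, not the adapter's exact code):

```python
import torch

def stack_qkv_lora_b(weights: dict, q_name: str):
    k_name = q_name.replace("q_proj", "k_proj")
    v_name = q_name.replace("q_proj", "v_proj")

    q_w = weights[q_name]                      # q_proj lora_b is assumed to exist
    ref = weights.get(k_name, weights.get(v_name))
    assert ref is not None, "need at least one of k_proj/v_proj to infer the kv shape"

    # If k_proj (or v_proj) has no lora, initialize it to zero with the kv shape.
    k_w = weights.get(k_name, torch.zeros_like(ref))
    v_w = weights.get(v_name, torch.zeros_like(ref))

    kv_w = torch.stack([k_w, v_w], dim=0)      # (2, kv_output_dim, r)
    return q_w, kv_w
```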
@@ -93,14 +93,14 @@ class LoRAManager:
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
-        # Target module names in HuggingFace LoRA configs.
+        # Target module names in huggingface lora configs.
         # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
         self.hf_target_names: Set[str] = set()
         for name, path in self.lora_paths.items():
             self.configs[name] = LoRAConfig(path)
             self.hf_target_names.update(self.configs[name].target_modules)

-        # Target LoRA weight names for lora_a and lora_b modules respectively.
+        # Target lora weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]
@@ -119,11 +119,11 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter

-        # misc LoRA configs
+        # misc lora configs
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])

         if self.lora_backend == "flashinfer":
-            # FIXME: remove the restrictions after supporting multi-rank for flashinfer backend
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
             max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
             scaling = list(self.loras.values())[0].scaling
             assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
@@ -144,16 +144,16 @@ class LoRAManager:
             self.lora_modules,
         )

-        # Initialize target LoRA modules in memory pool
+        # Initialize target lora modules in memory pool
         self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-        # load active LoRAs into LoRA memory pool
+        # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

-        # set up batch info shared by all LoRA modules
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size

         if (
@@ -221,7 +221,7 @@ class LoRAManager:
         )
         self.lora_backend.set_batch_info(batch_info)

-        # call set_lora_info for each LoRA modules
+        # call set_lora_info for each lora modules
         for layer_id, modules in self.lora_modules.items():
             for module_name, module in modules:
                 if "qkv_proj" in module_name:
...