Unverified Commit e8e18dcd authored by Lianmin Zheng, committed by GitHub

Revert "fix some typos" (#6244)

parent bad7c26f
......@@ -20,7 +20,7 @@ class AttentionBackend(ABC):
raise NotImplementedError()
def init_cuda_graph_state(self, max_bs: int):
"""Init the global shared states for CUDA graph."""
"""Init the global shared states for cuda graph."""
raise NotImplementedError()
def init_forward_metadata_capture_cuda_graph(
......@@ -33,7 +33,7 @@ class AttentionBackend(ABC):
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
"""Init the metadata for a forward pass for capturing a CUDA graph."""
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
def init_forward_metadata_replay_cuda_graph(
......@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Init the metadata for a forward pass for replaying a CUDA graph."""
"""Init the metadata for a forward pass for replaying a cuda graph."""
raise NotImplementedError()
def get_cuda_graph_seq_len_fill_value(self):
......
......@@ -15,7 +15,7 @@ if TYPE_CHECKING:
class DoubleSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
- # Lazy import to avoid the initialization of CUDA context
+ # Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
extend_attention_fwd,
flash_decode_attention_fwd,
......
......@@ -664,7 +664,7 @@ class FlashInferIndicesUpdaterDecode:
kv_indptr = kv_indptr[: bs + 1]
if wrapper.is_cuda_graph_enabled:
- # Directly write to the CUDA graph input buffer
+ # Directly write to the cuda graph input buffer
kv_indices = wrapper._paged_kv_indices_buf
else:
kv_indices = torch.empty(
......@@ -1173,7 +1173,7 @@ def fast_decode_plan(
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
Modifications:
-   - Remove unnecessary device-to-device copy for the CUDA graph buffers.
+   - Remove unnecessary device-to-device copy for the cuda graph buffers.
- Remove unnecessary host-to-device copy for the metadata buffers.
"""
batch_size = len(last_page_len)
......
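For readers unfamiliar with the copy elimination described in the docstring above, here is a minimal, self-contained sketch of the idea, using hypothetical buffer names rather than the actual FlashInfer plan internals: metadata is written once into the persistent buffer a captured CUDA graph already reads from, instead of being staged and copied again.

import torch

# Hypothetical persistent buffer assumed to be the one a captured CUDA graph reads from.
persistent_kv_indptr = torch.zeros(65, dtype=torch.int32)

def plan_with_extra_copy(kv_indptr_host: torch.Tensor) -> None:
    # Naive path: stage the metadata first, then copy again into the graph buffer
    # (on GPU this is a host-to-device copy followed by a device-to-device copy).
    staging = kv_indptr_host.to(persistent_kv_indptr.device)
    persistent_kv_indptr[: staging.numel()].copy_(staging)

def plan_direct(kv_indptr_host: torch.Tensor) -> None:
    # Fast path: a single copy, straight into the buffer the captured graph uses.
    persistent_kv_indptr[: kv_indptr_host.numel()].copy_(kv_indptr_host)

plan_direct(torch.arange(5, dtype=torch.int32))
print(persistent_kv_indptr[:5].tolist())  # [0, 1, 2, 3, 4]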
......@@ -874,7 +874,7 @@ def fast_mla_decode_plan(
) -> None:
"""A faster version of BatchMLAPagedAttentionWrapper::plan,
for skipping the stream synchronization in original plan function during
- CUDA graph replaying.
+ cuda graph replaying.
"""
self._causal = causal
self._page_size = page_size
......
......@@ -92,7 +92,7 @@ class TritonAttnBackend(AttentionBackend):
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
- # Lazy import to avoid the initialization of CUDA context
+ # Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd,
)
......
......@@ -257,7 +257,7 @@ class VisionFlash3Attention(nn.Module):
**kwargs,
):
if not _is_cuda:
raise Exception("VisionFlash3Attention is only available for CUDA")
raise Exception("VisionFlash3Attention is only available for cuda")
super().__init__()
def forward(
......
......@@ -237,7 +237,7 @@ def dp_scatter(
forward_batch: ForwardBatch,
):
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
- # since local_tokens may be padded for CUDA graph
+ # since local_tokens may be padded for cuda graph
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
local_tokens.fill_(0)
......
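The comment above is the central subtlety of dp_scatter: local_tokens can be padded to a fixed CUDA-graph batch size, so only the first local_num_tokens rows are real. A toy illustration follows; the shapes and the slicing convention are assumptions for clarity, not the exact get_dp_local_info contract.

import torch

hidden_size = 4
global_tokens = torch.arange(6 * hidden_size, dtype=torch.float32).view(6, hidden_size)

# Hypothetical local info: this rank owns 3 tokens starting at global row 2.
local_start_pos, local_num_tokens = 2, 3
# Padded to a fixed CUDA-graph batch size of 8, so shape[0] != local_num_tokens.
local_tokens = torch.empty(8, hidden_size)

local_tokens.fill_(0)
local_tokens[:local_num_tokens] = global_tokens[
    local_start_pos : local_start_pos + local_num_tokens
]
print(local_tokens[: local_num_tokens + 1])  # three real rows, then a zero pad row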
......@@ -166,7 +166,7 @@ class LogitsMetadata:
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
if self.global_num_tokens_for_logprob_cpu is None:
- # we are capturing CUDA graph
+ # we are capturing cuda graph
return
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
......
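As a quick worked example of the cumsum that follows the early return above (the per-rank counts below are made up): an exclusive prefix sum of the token counts gives each data-parallel rank's start offset into the gathered tensor.

import torch

global_num_tokens_for_logprob_gpu = torch.tensor([3, 5, 2, 4])  # made-up per-rank counts

cumtokens = torch.cumsum(global_num_tokens_for_logprob_gpu, dim=0)  # [3, 8, 10, 14]
starts = torch.cat([torch.zeros(1, dtype=cumtokens.dtype), cumtokens[:-1]])
print(starts.tolist())  # [0, 3, 8, 10] -> start offset of each DP rank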
......@@ -38,7 +38,7 @@ try:
except ImportError:
VLLM_AVAILABLE = False
- # Define empty classes as placeholders when vLLM is not available
+ # Define empty classes as placeholders when vllm is not available
class DummyConfig:
def override_quantization_method(self, *args, **kwargs):
return None
......@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
raise ValueError(
f"{quantization} quantization requires some operators from vllm. "
"Please install vLLM by `pip install vllm==0.8.4`"
"Please install vllm by `pip install vllm==0.8.4`"
)
return QUANTIZATION_METHODS[quantization]
......@@ -231,7 +231,7 @@ original_isinstance = builtins.isinstance
def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
"""
Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
- can recognize SGLang layers
+ can recognize sglang layers
"""
if not VLLM_AVAILABLE:
return
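A self-contained sketch of the isinstance patch described in the docstring above; the two classes are hypothetical stand-ins for a vLLM base layer and the corresponding SGLang layer, and the real patch maps several such class pairs.

import builtins

class VllmLinearBase: ...   # stand-in for a vLLM base layer class
class SglangLinear: ...     # stand-in for the SGLang layer that should match it

_original_isinstance = builtins.isinstance

def _patched_isinstance(obj, cls):
    # Let SGLang layers pass isinstance checks against the vLLM base class.
    if cls is VllmLinearBase and _original_isinstance(obj, SglangLinear):
        return True
    return _original_isinstance(obj, cls)

builtins.isinstance = _patched_isinstance
try:
    print(isinstance(SglangLinear(), VllmLinearBase))  # True while patched
finally:
    builtins.isinstance = _original_isinstance  # analogous to calling with reverse=True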
......@@ -267,7 +267,7 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
"""
Monkey patch the apply function of vllm's FusedMoEMethodBase.
- Convert SGLang arguments to vLLM arguments.
+ Convert sglang arguments to vllm arguments.
"""
original_apply = class_obj.apply
sig = inspect.signature(original_apply)
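A rough sketch of what the signature captured above can be used for: wrap apply and drop keyword arguments the original method does not accept, so SGLang call sites can pass extra arguments to a vLLM method. The class and parameter names here are illustrative only.

import inspect

class FakeMoEMethod:                       # hypothetical stand-in, not a vLLM class
    def apply(self, layer, x, router_logits, top_k):
        return f"apply(top_k={top_k})"

def monkey_patch_apply(class_obj):
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)

    def new_apply(self, *args, **kwargs):
        # Keep only keyword arguments that the wrapped signature actually accepts.
        accepted = {k: v for k, v in kwargs.items() if k in sig.parameters}
        return original_apply(self, *args, **accepted)

    class_obj.apply = new_apply

monkey_patch_apply(FakeMoEMethod)
# `correction_bias` is not in the original signature, so it is silently dropped.
print(FakeMoEMethod().apply(layer=None, x=None, router_logits=None, top_k=2, correction_bias=None))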
......@@ -329,6 +329,6 @@ def monkey_patch_quant_configs():
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
- # Only apply monkey patches if vLLM is available
+ # Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
monkey_patch_quant_configs()
......@@ -208,7 +208,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
- # Use torch Parameter to avoid CUDA graph capturing issue
+ # Use torch Parameter to avoid cuda graph capturing issue
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
......
......@@ -363,7 +363,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if self._is_wNa16_group_channel(weight_quant, input_quant):
if not VLLM_AVAILABLE:
raise ImportError(
"vLLM is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vLLM"
"vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
)
if (
self.quant_format == CompressionFormat.marlin_24.value
......@@ -409,7 +409,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if self._is_fp8_w8a16(weight_quant, input_quant):
if not VLLM_AVAILABLE:
raise ImportError(
"vLLM is not installed, to use CompressedTensorsW8A16Fp8, please install vLLM"
"vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
)
is_static_input_scheme = input_quant and not input_quant.dynamic
return CompressedTensorsW8A16Fp8(
......@@ -491,7 +491,7 @@ class CompressedTensorsConfig(QuantizationConfig):
):
if not VLLM_AVAILABLE:
raise ImportError(
"vLLM is not installed, to use CompressedTensors24, please install vLLM"
"vllm is not installed, to use CompressedTensors24, please install vllm"
)
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
......
......@@ -65,7 +65,7 @@ class CompressedTensorsMoEMethod:
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
if not VLLM_AVAILABLE:
raise ImportError(
"vLLM is not installed, to use CompressedTensorsWNA16MoEMethod, please install vLLM."
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
)
return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
......
......@@ -27,10 +27,10 @@ except ImportError:
MARLIN_FP8_AVAILABLE = False
def apply_fp8_marlin_linear(*args, **kwargs):
raise ImportError("vLLM is not installed")
raise ImportError("vllm is not installed")
def prepare_fp8_layer_for_marlin(*args, **kwargs):
raise ImportError("vLLM is not installed")
raise ImportError("vllm is not installed")
__all__ = ["CompressedTensorsW8A16Fp8"]
......@@ -45,7 +45,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
if not MARLIN_FP8_AVAILABLE:
raise ImportError(
"vLLM is not installed. To use CompressedTensorsW8A16Fp8, please install vLLM"
"vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
)
@classmethod
......
......@@ -357,7 +357,7 @@ def apply_fp8_linear(
# Fused GEMM_DQ
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
- # Fall back to vLLM cutlass w8a8 fp8 kernel
+ # Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
qinput,
weight,
......@@ -493,7 +493,7 @@ def apply_fp8_linear(
if cutlass_fp8_supported:
try:
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
- # Fall back to vLLM cutlass w8a8 fp8 kernel
+ # Fall back to vllm cutlass w8a8 fp8 kernel
output = ops.cutlass_scaled_mm(
qinput,
weight,
......
......@@ -186,8 +186,8 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
It supports multiple scaling factors. Since multiple LoRA adapters may have
different scaling factors, we need multiple cos/sin caches. In this way,
- instead of running rotary embedding kernel per LoRA adapter, we can run multiple
- LoRA adapters in a batched way.
+ instead of running rotary embedding kernel per lora, we can run multiple
+ lora in a batched way.
In addition to that, we also keep the cos/sin cache for the scaling factor
of 1 (default) at all times.
......
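A simplified sketch of the multi-cache idea in the docstring above: precompute one cos/sin cache per LoRA scaling factor (always including the default factor 1.0), so a batch mixing adapters never rebuilds caches on the fly. The helper and its toy sizes are assumptions, not the actual kernel interface.

import torch

def build_cos_sin_cache(head_dim: int, max_pos: int, base: float, scaling: float) -> torch.Tensor:
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_pos).float() / scaling      # linear position scaling
    freqs = torch.outer(positions, inv_freq)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)

scaling_factors = [1.0, 2.0, 4.0]                             # 1.0 is always kept
caches = {s: build_cos_sin_cache(64, 128, 10000.0, s) for s in scaling_factors}
print({s: tuple(c.shape) for s, c in caches.items()})         # each cache is (128, 64)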
......@@ -41,13 +41,13 @@ class BaseLoRABackend:
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of LoRA a modules with current backend.
"""Run segment Gemm of lora a modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
- weights: a set of LoRA weights with shape (num_lora, c * r, input_dim),
- here r is LoRA rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
+ weights: a set of lora weights with shape (num_lora, c * r, input_dim),
+ here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
usually input_dim is much larger than r
Returns:
result with shape (s, c * r)
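The shape contract spelled out above is easy to check with a naive reference loop (not the actual backend kernel); seg_lens and weight_indices are hypothetical inputs naming which adapter each segment uses.

import torch

def segment_gemm_lora_a(x, weights, seg_lens, weight_indices):
    # x: (s, input_dim), weights: (num_lora, c * r, input_dim) -> result: (s, c * r)
    out = torch.empty(x.shape[0], weights.shape[1], dtype=x.dtype)
    start = 0
    for seg_len, w_idx in zip(seg_lens, weight_indices):
        out[start : start + seg_len] = x[start : start + seg_len] @ weights[w_idx].T
        start += seg_len
    return out

x = torch.randn(7, 16)                # two sequences with lengths 3 and 4 (s = 7)
weights = torch.randn(2, 3 * 8, 16)   # 2 adapters, c = 3 (qkv_proj), r = 8
print(segment_gemm_lora_a(x, weights, [3, 4], [0, 1]).shape)  # torch.Size([7, 24])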
......@@ -57,12 +57,12 @@ class BaseLoRABackend:
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of LoRA b modules with current backend.
"""Run segment Gemm of lora b modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
- x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is LoRA rank
- weights: a set of LoRA weights with shape (num_lora, output_dim, r)
+ x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
+ weights: a set of lora weights with shape (num_lora, output_dim, r)
usually output_dim is much larger than r
Returns:
result with shape (s, output_dim)
......@@ -77,7 +77,7 @@ class BaseLoRABackend:
*args,
**kwargs,
) -> torch.Tensor:
"""Run the LoRA pass for QKV Layer.
"""Run the lora pass for QKV Layer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
......@@ -100,7 +100,7 @@ class BaseLoRABackend:
*args,
**kwargs,
) -> torch.Tensor:
"""Run the LoRA pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
......
......@@ -117,7 +117,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
dtype=x.dtype,
)
- # Compute LoRA for gate and up proj respectively
+ # Compute lora for gate and up proj respectively
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(),
weights=gate_up_lora_b[0],
......
......@@ -198,7 +198,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if self.lora_backend.fuse_stacked_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The LoRA rank of q and kv should be the same when enabling fusion of qkv lora_b"
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
......
......@@ -40,7 +40,7 @@ class LoRALayer(nn.Module):
self.config: LoRAConfig = config
self.base_hf_config: AutoConfig = base_hf_config
- # LoRA weights in cpu. The weights are loaded from checkpoint.
+ # lora weights in cpu. The weights are loaded from checkpoint.
self.weights: Dict[str, torch.Tensor] = {}
......@@ -97,7 +97,7 @@ class LoRAAdapter(nn.Module):
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
- # Collect target q/k/v modules. This process is necessary since there might be no LoRA attached to k_proj
+ # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
target_module = set()
for weight_name in weight_names:
if "k_proj" in weight_name:
......@@ -110,7 +110,7 @@ class LoRAAdapter(nn.Module):
return
for weight_name in weight_names:
- # We assume every LoRA adaptor should contain LoRA modules for q_proj
+ # We assume every lora adaptor should contain lora modules for q_proj
if "q_proj" in weight_name:
q_name = weight_name
k_name = weight_name.replace("q_proj", "k_proj")
......@@ -118,7 +118,7 @@ class LoRAAdapter(nn.Module):
kv_name = weight_name.replace("q_proj", "kv_proj")
qkv_name = weight_name.replace("q_proj", "qkv_proj")
- # If k_proj doesn't have LoRA, initialize it to zero
+ # If k_proj doesn't have lora, initialize it to zero
k_proj_weight = (
weights[k_name]
if "k_proj" in target_module
......
......@@ -93,14 +93,14 @@ class LoRAManager:
# Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {}
- # Target module names in HuggingFace LoRA configs.
+ # Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
self.hf_target_names: Set[str] = set()
for name, path in self.lora_paths.items():
self.configs[name] = LoRAConfig(path)
self.hf_target_names.update(self.configs[name].target_modules)
- # Target LoRA weight names for lora_a and lora_b modules respectively.
+ # Target lora weight names for lora_a and lora_b modules respectively.
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
self.lora_weight_names: Set[Tuple[str]] = set(
[get_stacked_name(module) for module in self.hf_target_names]
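For illustration, a hypothetical get_stacked_name consistent with the example set above: HuggingFace per-module names are mapped to the fused SGLang module they load into (the exact mapping lives elsewhere in the codebase).

def get_stacked_name(name: str):
    # Hypothetical mapping: (stacked module for lora_a, stacked module for lora_b).
    stacked = {
        "q_proj": ("qkv_proj", "q_proj"),
        "k_proj": ("qkv_proj", "kv_proj"),
        "v_proj": ("qkv_proj", "kv_proj"),
        "gate_proj": ("gate_up_proj", "gate_up_proj"),
        "up_proj": ("gate_up_proj", "gate_up_proj"),
    }
    return stacked.get(name, (name, name))

hf_target_names = {"q_proj", "k_proj", "v_proj", "o_proj"}
print({get_stacked_name(m) for m in hf_target_names})
# e.g. {('qkv_proj', 'q_proj'), ('qkv_proj', 'kv_proj'), ('o_proj', 'o_proj')}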
......@@ -119,11 +119,11 @@ class LoRAManager:
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
- # misc LoRA configs
+ # misc lora configs
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
if self.lora_backend == "flashinfer":
- # FIXME: remove the restrictions after supporting multi-rank for flashinfer backend
+ # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
scaling = list(self.loras.values())[0].scaling
assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
......@@ -144,16 +144,16 @@ class LoRAManager:
self.lora_modules,
)
- # Initialize target LoRA modules in memory pool
+ # Initialize target lora modules in memory pool
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
- # load active LoRAs into LoRA memory pool
+ # load active loras into lora memory pool
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
- # set up batch info shared by all LoRA modules
+ # set up batch info shared by all lora modules
bs = forward_batch.batch_size
if (
......@@ -221,7 +221,7 @@ class LoRAManager:
)
self.lora_backend.set_batch_info(batch_info)
- # call set_lora_info for each LoRA modules
+ # call set_lora_info for each lora modules
for layer_id, modules in self.lora_modules.items():
for module_name, module in modules:
if "qkv_proj" in module_name:
......