"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "99a02f07763c39ad6cbfbd53d741a3aac839b684"
Unverified Commit fb1f28cb authored by Lianmin Zheng, committed by GitHub

Clean up the comments and names under python/sglang/srt/layers (#1047)

parent fb7421db
@@ -11,6 +11,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

+"""Fused operators for activation layers."""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
...
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

+"""
+Memory-efficient attention for decoding.
+"""
+
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
    tl.store(out_ptrs, acc)


-def _token_att_m_fwd(
+def _decode_att_m_fwd(
    q,
    k_buffer,
    att_out,
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
    )


-def _token_softmax_reducev_fwd(
+def _decode_softmax_reducev_fwd(
    logics,
    v_buffer,
    o,
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
    )


-def token_attention_fwd(
+def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
@@ -312,7 +316,7 @@ def token_attention_fwd(
        (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
    )

-    _token_att_m_fwd(
+    _decode_att_m_fwd(
        q,
        k_buffer,
        att_m,
@@ -324,7 +328,7 @@ def token_attention_fwd(
        sm_scale,
        logit_cap,
    )
-    _token_softmax_reducev_fwd(
+    _decode_softmax_reducev_fwd(
        att_m,
        v_buffer,
        o,
...
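For orientation, a minimal plain-PyTorch sketch of what the two renamed kernels compute together for one decoding step: _decode_att_m_fwd produces the per-head attention logits (stage 1) and _decode_softmax_reducev_fwd applies the softmax and reduces over V (stage 2). The tensor layout, the absence of grouped KV heads, and the logit-cap handling below are simplifying assumptions; the real kernels run in Triton over the paged k_buffer/v_buffer.

import torch

def decode_attention_reference(q, k, v, sm_scale, logit_cap=0.0):
    # q: [num_heads, head_dim]             -- query of the single new token
    # k, v: [seq_len, num_heads, head_dim] -- cached keys/values of one sequence
    scores = torch.einsum("hd,shd->hs", q, k) * sm_scale      # stage 1: scaled q.K
    if logit_cap > 0:
        scores = logit_cap * torch.tanh(scores / logit_cap)   # optional soft cap on logits
    probs = torch.softmax(scores, dim=-1)                     # stage 2: softmax ...
    return torch.einsum("hs,shd->hd", probs, v)               # ... and weighted sum over V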
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

+"""
+Memory-efficient attention for prefill.
+It supports page size = 1 and prefill with KV cache (i.e. extend).
+"""
+
import torch
import triton
import triton.language as tl

-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.layers.prefill_attention import context_attention_fwd

CUDA_CAPABILITY = torch.cuda.get_device_capability()
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

+"""Fused operators for normalization layers."""
+
from typing import Optional, Tuple, Union

import torch
...
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

+"""
+Memory-efficient attention for prefill.
+It supports page size = 1.
+"""
+
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
...
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# temporarily adapted from vLLM
# FIXME: in progress of refactoring the model loader
from typing import Dict, Type

from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsConfig,
)
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

from sglang.srt.layers.quantization.fp8 import Fp8Config

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "fp8": Fp8Config,
    # The order of gptq methods is important for config.py iteration over
    # override_quantization_method(..)
    "marlin": MarlinConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]


__all__ = [
    "QuantizationConfig",
    "get_quantization_config",
    "QUANTIZATION_METHODS",
]
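As an illustration, a minimal usage sketch of the registry above; the module path sglang.srt.layers.quantization is an assumption inferred from the Fp8Config import, and only QUANTIZATION_METHODS and get_quantization_config come from the file itself.

from sglang.srt.layers.quantization import (
    QUANTIZATION_METHODS,
    get_quantization_config,
)

# List the registered method names, e.g. "awq", "fp8", "gptq", ...
print(sorted(QUANTIZATION_METHODS))

# Resolve a method name (typically read from a model's quantization config)
# to its QuantizationConfig class; unknown names raise ValueError.
quant_cls = get_quantization_config("fp8")
print(quant_cls.__name__)  # -> "Fp8Config"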
@@ -20,8 +20,8 @@ from flashinfer.cascade import merge_state
from torch import nn

from sglang.global_config import global_config
+from sglang.srt.layers.decode_attention import decode_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import global_server_args_dict
@@ -95,7 +95,7 @@ class RadixAttention(nn.Module):
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

-        token_attention_fwd(
+        decode_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
...
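Taken together, the hunks above amount to the following renames, summarized here as a small reference map for updating downstream imports (the entries simply restate the diffs shown above):

# Module renames (old import path -> new import path)
MODULE_RENAMES = {
    "sglang.srt.layers.token_attention": "sglang.srt.layers.decode_attention",
    "sglang.srt.layers.context_flashattention_nopad": "sglang.srt.layers.prefill_attention",
}

# Function renames within the decode-attention module
FUNCTION_RENAMES = {
    "token_attention_fwd": "decode_attention_fwd",
    "_token_att_m_fwd": "_decode_att_m_fwd",
    "_token_softmax_reducev_fwd": "_decode_softmax_reducev_fwd",
}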