Unverified commit 3a6e8b6d, authored by Lianmin Zheng, committed by GitHub

[Minor] move triton attention kernels into a separate folder (#1379)

parent fbb4754c
@@ -57,9 +57,9 @@ import pandas as pd
 import torch
 import torch.distributed as dist
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
-from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
...
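Note: alongside the kernel move, this commit relocates ModelConfig from sglang.srt.model_config to the new sglang.srt.configs package; the same rename recurs in the tp_worker and model_runner hunks below. A minimal sketch for downstream code that needs to import it across both layouts — the try/except shim is illustrative and not part of this commit:

```python
# ModelConfig now lives in the sglang.srt.configs package.
# Hypothetical compatibility shim for code that must run on both
# pre- and post-#1379 versions of sglang:
try:
    from sglang.srt.configs.model_config import ModelConfig  # new location
except ImportError:
    from sglang.srt.model_config import ModelConfig  # old location
```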
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""For constrained decoding."""
+
 import json
 from typing import Dict, Optional, Union
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Conversation templates."""
+"""Conversation chat templates."""

 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
...
@@ -16,11 +16,9 @@ limitations under the License.
 """Utilities for Huggingface Transformers."""

 import contextlib
-import functools
-import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
+from typing import Dict, Optional, Type, Union

 from huggingface_hub import snapshot_download
 from transformers import (
...
@@ -22,13 +22,20 @@ from flashinfer.cascade import merge_state
 from torch import nn

 from sglang.global_config import global_config
-from sglang.srt.layers.decode_attention import decode_attention_fwd
-from sglang.srt.layers.extend_attention import extend_attention_fwd
+from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
+from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict


 class RadixAttention(nn.Module):
+    """
+    The attention layer implementation.
+    Now it has two backends: FlashInfer and Triton.
+    FlashInfer is faster and Triton is easier to customize.
+    It supports two operators: extend (i.e. prefill with cached prefix) and decode.
+    """
+
     def __init__(
         self,
         num_heads: int,
@@ -49,8 +56,10 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         self.sliding_window_size = sliding_window_size if sliding_window_size else -1

+        # Choose backend
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
             and self.qk_head_dim == self.v_head_dim
@@ -61,8 +70,6 @@ class RadixAttention(nn.Module):
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton

-        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         if self.qk_head_dim != self.v_head_dim:
             o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
...
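The docstring added above summarizes what __init__ now does: it selects a backend once at construction time and binds extend_forward/decode_forward to the chosen implementation, so each call is a plain attribute lookup with no per-call backend branch. A stripped-down sketch of that pattern, assuming stub method bodies and only the arguments visible in this diff:

```python
# Minimal sketch of RadixAttention's backend dispatch (bodies are stubs).
class AttentionLayer:
    def __init__(self, qk_head_dim, v_head_dim, disable_flashinfer=False):
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        # Choose backend once: FlashInfer is faster, but the Triton path
        # also covers the case where the qk and v head dims differ.
        if not disable_flashinfer and qk_head_dim == v_head_dim:
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton

    def extend_forward_flashinfer(self, q, k, v, input_metadata):
        ...  # prefill with cached prefix via FlashInfer

    def decode_forward_flashinfer(self, q, k, v, input_metadata):
        ...  # single-token decode via FlashInfer

    def extend_forward_triton(self, q, k, v, input_metadata):
        ...  # prefill with cached prefix via the Triton kernels

    def decode_forward_triton(self, q, k, v, input_metadata):
        ...  # single-token decode via the Triton kernels
```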
@@ -22,7 +22,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.prefill_attention import context_attention_fwd
+from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd

 CUDA_CAPABILITY = torch.cuda.get_device_capability()
...
@@ -29,6 +29,7 @@ import torch.distributed
 import torch.distributed as dist

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -52,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
-from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
...
@@ -15,7 +15,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""ModelRunner runs the forward passes of the models."""
+"""Meta data for a forward pass."""

 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List
...
@@ -18,7 +18,6 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
-import json
 import logging
 import pkgutil
 from functools import lru_cache
@@ -45,6 +44,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
@@ -53,7 +53,6 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
...
@@ -6,8 +6,11 @@ from flashinfer import (
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention
 from sglang.srt.layers.token_attention import token_attention_fwd
+from sglang.srt.layers.triton_attention.extend_attention import (
+    extend_attention_fwd,
+    redundant_attention,
+)

 flashinfer_prefill_wrapper = None
 flashinfer_decode_wrapper = None
...
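Taken together, the import rewrites across these hunks describe the actual move: the three Triton attention kernels now live in a triton_attention subpackage under sglang/srt/layers. The before/after layout below is inferred from the import changes alone; the relocated files themselves are not shown on this page:

```python
# Before #1379 (inferred):
#   sglang/srt/layers/decode_attention.py
#   sglang/srt/layers/extend_attention.py
#   sglang/srt/layers/prefill_attention.py
# After #1379 (inferred):
#   sglang/srt/layers/triton_attention/decode_attention.py
#   sglang/srt/layers/triton_attention/extend_attention.py
#   sglang/srt/layers/triton_attention/prefill_attention.py

# Updated imports, as used by the hunks above:
from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import (
    extend_attention_fwd,
    redundant_attention,
)
from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
```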