Unverified Commit 9c7a4618 authored by Lianmin Zheng, committed by GitHub

[Doc] Steps to add a new attention backend (#8155)

parent 7750b91c
@@ -56,7 +56,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-        part: [0, 1, 2, 3, 4, 5, 6, 7, 8]
+        part: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
@@ -69,7 +69,7 @@ jobs:
        timeout-minutes: 30
        run: |
          cd test/srt
-          python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9
+          python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 10
  unit-test-backend-2-gpu:
    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
...
@@ -52,3 +52,31 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
```
## Steps to add a new attention backend
To add a new attention backend, you can learn from the existing backends
(`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`)
and follow the steps below. A minimal skeleton of the interface is sketched after the list.
1. Run without CUDA graph. Support the two forward functions and the metadata initializer:
   - `forward_extend`
     - Will be used for prefill, prefill with KV cache, and target verification
     - It will be called once per layer
   - `forward_decode`
     - Will be used for normal decode and draft decode
     - It will be called once per layer
   - `init_forward_metadata`
     - Initialize the class and the common metadata shared by all layers
     - Call the plan function for optimizations like `split_kv`
     - It will be called once per forward pass
2. Run with CUDA graph. It has two phases (capture and replay), and you need to implement three functions:
   - `init_cuda_graph_state`
     - It will be called once during the lifetime of the server
     - Create all common shared buffers
   - `init_forward_metadata_capture_cuda_graph`
     - It will be called before capturing a CUDA graph
     - It is similar to `init_forward_metadata` but writes the metadata to some pre-defined buffers
   - `init_forward_metadata_replay_cuda_graph`
     - It will be called before replaying a CUDA graph
     - This function is on the critical path and needs to be fast
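
For orientation, here is a minimal sketch of a backend that implements the functions listed above. The class name `MyAttentionBackend`, the argument lists, and the `cuda_graph_kv_indices` buffer are illustrative assumptions; copy the exact signatures from the existing backends (e.g. `triton_backend.py`) when writing a real implementation.

```python
import torch


class MyAttentionBackend:
    """Sketch of the functions described in steps 1 and 2 above."""

    def __init__(self, model_runner):
        # Cache model/runner info (num heads, head dim, KV cache pool, ...) here.
        self.model_runner = model_runner
        self.forward_metadata = None

    # ---- Step 1: eager mode (no CUDA graph) ----
    def init_forward_metadata(self, forward_batch):
        # Called once per forward pass. Build the metadata shared by all layers
        # (e.g. KV indices, sequence lengths) and run any planning step such as
        # split_kv here.
        self.forward_metadata = ...  # backend-specific metadata object

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # Called once per layer for prefill, prefill with KV cache,
        # and target verification.
        raise NotImplementedError

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # Called once per layer for normal decode and draft decode.
        raise NotImplementedError

    # ---- Step 2: CUDA graph mode (capture and replay) ----
    def init_cuda_graph_state(self, max_bs):
        # Called once during the lifetime of the server. Allocate the shared
        # buffers that every captured graph will read from, sized for max_bs.
        self.cuda_graph_kv_indices = torch.zeros(
            (max_bs,), dtype=torch.int32, device="cuda"
        )

    def init_forward_metadata_capture_cuda_graph(self, bs, req_pool_indices, seq_lens, **kwargs):
        # Called before capturing a CUDA graph. Same job as init_forward_metadata,
        # but the metadata must be written into the pre-allocated buffers created
        # in init_cuda_graph_state.
        ...

    def init_forward_metadata_replay_cuda_graph(self, bs, req_pool_indices, seq_lens, **kwargs):
        # Called before replaying a captured CUDA graph. This is on the critical
        # path, so only update the buffers in place; avoid allocations and
        # host-device synchronizations.
        ...
```

A new backend is typically wired into the server's backend selection so it can be chosen with `--attention-backend <name>`, like the `ascend` example above.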
@@ -13,14 +13,14 @@
# ==============================================================================
"""
The definition of objects transferred between different
-processes (TokenizerManager, DetokenizerManager, Controller).
+processes (TokenizerManager, DetokenizerManager, Scheduler).
"""
import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
@@ -545,7 +545,7 @@ class EmbeddingReqInput:
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # Dummy input embeds for compatibility
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
@@ -953,17 +953,6 @@ class ProfileReqType(Enum):
    STOP_PROFILE = 2
-class ExpertDistributionReq(Enum):
-    START_RECORD = 1
-    STOP_RECORD = 2
-    DUMP_RECORD = 3
-@dataclass
-class ExpertDistributionReqOutput:
-    pass
@dataclass
class ProfileReq:
    type: ProfileReqType
@@ -1013,6 +1002,17 @@ class HealthCheckOutput:
    pass
+class ExpertDistributionReq(Enum):
+    START_RECORD = 1
+    STOP_RECORD = 2
+    DUMP_RECORD = 3
+@dataclass
+class ExpertDistributionReqOutput:
+    pass
@dataclass
class Function:
    description: Optional[str] = None
...
@@ -155,11 +155,11 @@ suites = {
    "per-commit-2-gpu": [
        TestFile("models/lora/test_lora_tp.py", 116),
        TestFile("test_data_parallelism.py", 73),
-        TestFile("test_dp_attention.py", 137),
+        TestFile("test_dp_attention.py", 277),
        TestFile("test_mla_tp.py", 170),
        TestFile("test_patch_torch.py", 19),
        TestFile("test_update_weights_from_distributed.py", 103),
-        TestFile("test_release_memory_occupation.py", 44),
+        TestFile("test_release_memory_occupation.py", 127),
    ],
    "per-commit-2-gpu-amd": [
        TestFile("models/lora/test_lora_tp.py", 116),
@@ -170,7 +170,7 @@ suites = {
    ],
    "per-commit-4-gpu": [
        TestFile("test_local_attn.py", 250),
-        TestFile("test_pp_single_node.py", 150),
+        TestFile("test_pp_single_node.py", 372),
        TestFile("test_multi_instance_release_memory_occupation.py", 64),
    ],
    "per-commit-4-gpu-deepep": [
@@ -182,12 +182,12 @@ suites = {
    "per-commit-8-gpu": [
        # Disabled because it hangs on the CI.
        # TestFile("test_moe_ep.py", 181),
-        TestFile("test_disaggregation.py", 270),
+        TestFile("test_disaggregation.py", 499),
        TestFile("test_disaggregation_different_tp.py", 155),
-        TestFile("test_full_deepseek_v3.py", 463),
+        TestFile("test_full_deepseek_v3.py", 333),
    ],
    "per-commit-8-gpu-deepep": [
-        TestFile("test_deepep_large.py", 485),
+        TestFile("test_deepep_large.py", 338),
    ],
    "per-commit-8-gpu-amd": [
        TestFile("test_full_deepseek_v3.py", 250),
@@ -214,11 +214,11 @@ suites = {
        TestFile("test_nightly_gsm8k_eval_amd.py"),
    ],
    "vllm_dependency_test": [
-        TestFile("test_awq.py"),
-        TestFile("test_bnb.py"),
-        TestFile("test_gguf.py", 78),
-        TestFile("test_gptqmodel_dynamic.py", 72),
-        TestFile("test_vllm_dependency.py"),
+        TestFile("test_awq.py", 163),
+        TestFile("test_bnb.py", 5),
+        TestFile("test_gguf.py", 96),
+        TestFile("test_gptqmodel_dynamic.py", 102),
+        TestFile("test_vllm_dependency.py", 185),
    ],
}
...