Unverified commit ffe4aaee authored by Ying Sheng, committed by GitHub

Fix for T4 GPUs (#16)


Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
parent 5b27a1dc
@@ -32,6 +32,10 @@ pip install --upgrade pip
pip install -e "python[all]"
```
+### Notes
+- If you are using older GPUs (NVIDIA T4, V100), please use `pip install "triton>=2.2.0"` to avoid some bugs in the triton compiler (see the capability check sketched below)
+- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install sglang[openai]`
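A quick way to tell whether the first note applies to your machine is to ask PyTorch for the GPU's compute capability: T4 reports (7, 5) and V100 reports (7, 0), while Ampere and newer cards report a major version of 8 or higher. The check below is only an illustration added for this write-up, not part of the commit.

```python
import torch

# Illustrative check (not part of this commit): pre-Ampere GPUs such as the
# T4 (7, 5) and V100 (7, 0) are the cards the triton>=2.2.0 note above targets.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"{torch.cuda.get_device_name(0)}: compute capability {major}.{minor}")
    if major < 8:
        print('Older GPU detected; consider: pip install "triton>=2.2.0"')
else:
    print("No CUDA device visible.")
```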
## Quick Start
The example below shows how to use sglang to answer a multi-turn question.
@@ -197,7 +201,7 @@ for out in state.text_iter():
## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
However, it can also be used as a standalone API server.
-In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases.
+In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.
### Usage
Launch a server
@@ -221,6 +225,10 @@ curl http://localhost:30000/v1/completions \
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
```
+- If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
+```
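To see the automatic KV cache reuse mentioned above, one informal experiment is to send two completions that share a long prefix to the OpenAI-compatible endpoint from the curl example; the second request can hit the radix cache for the shared part. The snippet below is only an illustration: the endpoint comes from the README hunk above, while the JSON fields (`model`, `prompt`, `max_tokens`) follow the generic OpenAI completions format and are assumptions, not taken from this diff.

```python
import requests

prefix = "You are a helpful assistant. Context: " + "lorem ipsum " * 200

for question in ["Summarize the context.", "List three key points."]:
    # Both prompts share `prefix`; RadixAttention can reuse its KV cache entries.
    resp = requests.post(
        "http://localhost:30000/v1/completions",
        json={"model": "default", "prompt": prefix + question, "max_tokens": 64},
    )
    print(resp.json()["choices"][0]["text"])
```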
### Supported Models
- Llama
......
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang"
-version = "0.1.3"
+version = "0.1.4"
description = "A structured generation language for LLMs."
readme = "README.md"
requires-python = ">=3.8"
......
-__version__ = "0.1.3"
+__version__ = "0.1.4"
from sglang.api import *
from sglang.global_config import global_config
@@ -6,6 +6,9 @@ import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher

+CUDA_CAPABILITY = torch.cuda.get_device_capability()

@triton.jit
def _fwd_kernel(
    Q,
@@ -120,7 +123,11 @@ cached_kernel = None
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    BLOCK = 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK = 128
+    else:
+        BLOCK = 64
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
......
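This kernel (and the extend-attention kernel changed further below) gates its Triton tile size on `CUDA_CAPABILITY[0] >= 8`, falling back to smaller blocks on pre-Ampere GPUs such as the T4 and V100, presumably because the 128-wide tiles exceed the resource budget of those older parts. A minimal standalone sketch of that dispatch is shown below; `pick_block_size` is a hypothetical helper name, not something defined in sglang.

```python
import torch

CUDA_CAPABILITY = torch.cuda.get_device_capability()

def pick_block_size() -> int:
    # Same rule as the hunk above: keep the 128 tile on Ampere (sm_80) and newer,
    # drop to 64 on older cards such as T4 (7, 5) and V100 (7, 0).
    return 128 if CUDA_CAPABILITY[0] >= 8 else 64

print("BLOCK =", pick_block_size())
```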
@@ -2,6 +2,10 @@ import torch
import triton
import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.utils import wrap_kernel_launcher

+CUDA_CAPABILITY = torch.cuda.get_device_capability()

@triton.jit
@@ -153,6 +157,9 @@ def _fwd_kernel(
    tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])

+cached_kernel = None

def extend_attention_fwd(
    q_extend,
    k_extend,
@@ -175,7 +182,11 @@ def extend_attention_fwd(
    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
    """
-    BLOCK_M, BLOCK_N = 128, 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = 128, 128
+    else:
+        BLOCK_M, BLOCK_N = 64, 64
    Lq, Lk, Lv, Lo = (
        q_extend.shape[-1],
        k_extend.shape[-1],
@@ -193,6 +204,40 @@ def extend_attention_fwd(
    num_warps = 4 if Lk <= 64 else 8
    num_stages = 1
+    global cached_kernel
+    if cached_kernel:
+        cached_kernel(
+            grid,
+            num_warps,
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend,
+            k_buffer,
+            v_buffer,
+            req_to_tokens,
+            b_req_idx,
+            b_seq_len,
+            b_start_loc_extend,
+            b_seq_len_extend,
+            sm_scale,
+            kv_group_num,
+            q_extend.stride(0),
+            q_extend.stride(1),
+            k_extend.stride(0),
+            k_extend.stride(1),
+            v_extend.stride(0),
+            v_extend.stride(1),
+            o_extend.stride(0),
+            o_extend.stride(1),
+            k_buffer.stride(0),
+            k_buffer.stride(1),
+            v_buffer.stride(0),
+            v_buffer.stride(1),
+            req_to_tokens.stride(0),
+        )
+        return
    _fwd_kernel[grid](
        q_extend,
        k_extend,
@@ -226,6 +271,7 @@ def extend_attention_fwd(
        num_warps=num_warps,
        num_stages=num_stages,
    )
+    cached_kernel = wrap_kernel_launcher(_fwd_kernel)

def redundant_attention(
......
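The added fast path above follows the launcher-caching pattern already present in context_flashattention_nopad.py (note the `cached_kernel = None` context in its hunk header): the first call goes through `_fwd_kernel[grid](...)`, paying Triton's dispatch and specialization cost, and then `wrap_kernel_launcher` stores a launcher that later calls replay directly with the same argument list. A toy sketch of the idea, with made-up names and none of sglang's actual internals:

```python
from typing import Callable, Optional

# Toy stand-ins; the real code caches a Triton launch via sglang's wrap_kernel_launcher.
cached_launcher: Optional[Callable[..., int]] = None

def expensive_first_launch(*args: int) -> int:
    print("specializing kernel ...")  # stand-in for Triton's JIT/dispatch overhead
    return sum(args)

def fast_call(*args: int) -> int:
    global cached_launcher
    if cached_launcher is not None:
        return cached_launcher(*args)       # fast path: replay the cached launcher
    result = expensive_first_launch(*args)  # slow path: only the first call pays
    cached_launcher = lambda *a: sum(a)     # cache a cheap launcher for later calls
    return result

print(fast_call(1, 2, 3))  # prints "specializing kernel ...", then 6
print(fast_call(4, 5))     # reuses the cached launcher, prints 9
```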
@@ -5,6 +5,7 @@ import time
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union
+import warnings

import numpy as np
import rpyc
@@ -164,7 +165,7 @@ class ModelRpcServer(rpyc.Service):
                + self.tree_cache.evictable_size()
            )
            if available_size != self.max_total_num_token:
-                logger.warning(
+                warnings.warn(
                    "Warning: "
                    f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                    "KV cache pool leak detected!"
......