Unverified commit ffe4aaee authored by Ying Sheng, committed by GitHub

Fix for T4 GPUs (#16)


Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
parent 5b27a1dc
......@@ -32,6 +32,10 @@ pip install --upgrade pip
pip install -e "python[all]"
```
### Notes
- If you are using older GPUs (NVIDIA T4, V100), please use `pip install "triton>=2.2.0"` to avoid some bugs in the Triton compiler (a quick way to check whether your GPU falls into this category is sketched after these notes).
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install sglang[openai]`.
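
The older-GPU note above applies to pre-Ampere cards, i.e. CUDA compute capability below 8.0 (the T4 reports 7.5, the V100 reports 7.0). A minimal sketch to check which case your machine falls into, assuming PyTorch is installed with CUDA support:
```python
import torch

# (major, minor) compute capability: T4 -> (7, 5), V100 -> (7, 0), A100 -> (8, 0).
# This commit's Triton kernels switch to smaller block sizes whenever major < 8.
major, minor = torch.cuda.get_device_capability()

if major < 8:
    print(f"sm_{major}{minor}: older GPU, install triton>=2.2.0 as noted above.")
else:
    print(f"sm_{major}{minor}: the default dependencies should be fine.")
```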
## Quick Start
The example below shows how to use sglang to answer a multi-turn question.
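
The full Quick Start example is collapsed in this diff view. A rough sketch of what such a multi-turn program looks like with SGLang's frontend primitives (`function`, `system`, `user`, `assistant`, `gen`), assuming the OpenAI backend is configured; this is an illustration, not the exact example from the README:
```python
from sglang import OpenAI, assistant, function, gen, set_default_backend, system, user

@function
def multi_turn_question(s, question_1, question_2):
    # Build a chat transcript; each gen() call asks the backend for a completion.
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

# Assumes an OpenAI API key is available in the environment.
set_default_backend(OpenAI("gpt-3.5-turbo"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions there.",
)
print(state["answer_1"])
print(state["answer_2"])
```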
......@@ -197,7 +201,7 @@ for out in state.text_iter():
## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
However, it can also be used as a standalone API server.
In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases.
In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.
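
Because the standalone server exposes an OpenAI-compatible API (the `curl` example under Usage below targets `/v1/completions`), it can also be queried from Python. A minimal sketch, assuming a server launched as in the Usage section is listening on `localhost:30000` and the `requests` package is installed; the `"model": "default"` field is an assumption, mirroring typical OpenAI-compatible servers:
```python
import requests

# Endpoint and port mirror the curl example in the Usage section below.
resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",          # assumption: placeholder name for the served model
        "prompt": "Say this is a test.",
        "max_tokens": 16,
        "temperature": 0,
    },
    timeout=60,
)
print(resp.json())
```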
### Usage
Launch a server
......@@ -221,6 +225,10 @@ curl http://localhost:30000/v1/completions \
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
```
- If you see out-of-memory errors during serving, try reducing the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
```
### Supported Models
- Llama
......
......@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.1.3"
version = "0.1.4"
description = "A structured generation language for LLMs."
readme = "README.md"
requires-python = ">=3.8"
......
__version__ = "0.1.3"
__version__ = "0.1.4"
from sglang.api import *
from sglang.global_config import global_config
......@@ -6,6 +6,9 @@ import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
def _fwd_kernel(
Q,
......@@ -120,7 +123,11 @@ cached_kernel = None
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
    BLOCK = 128
    if CUDA_CAPABILITY[0] >= 8:
        BLOCK = 128
    else:
        # Smaller block for pre-Ampere GPUs (e.g. T4, V100), which have less shared memory.
        BLOCK = 64

    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
......
......@@ -2,6 +2,10 @@ import torch
import triton
import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
......@@ -153,6 +157,9 @@ def _fwd_kernel(
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
cached_kernel = None
def extend_attention_fwd(
    q_extend,
    k_extend,
......@@ -175,7 +182,11 @@ def extend_attention_fwd(
    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
    """
    BLOCK_M, BLOCK_N = 128, 128
    if CUDA_CAPABILITY[0] >= 8:
        BLOCK_M, BLOCK_N = 128, 128
    else:
        # Smaller tiles for pre-Ampere GPUs (e.g. T4, V100), which have less shared memory.
        BLOCK_M, BLOCK_N = 64, 64

    Lq, Lk, Lv, Lo = (
        q_extend.shape[-1],
        k_extend.shape[-1],
......@@ -193,6 +204,40 @@ def extend_attention_fwd(
    num_warps = 4 if Lk <= 64 else 8
    num_stages = 1

    global cached_kernel
    if cached_kernel:
        # Fast path: reuse the previously wrapped launcher instead of going
        # through Triton's JIT dispatch again on every call.
        cached_kernel(
            grid,
            num_warps,
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_start_loc_extend,
            b_seq_len_extend,
            sm_scale,
            kv_group_num,
            q_extend.stride(0),
            q_extend.stride(1),
            k_extend.stride(0),
            k_extend.stride(1),
            v_extend.stride(0),
            v_extend.stride(1),
            o_extend.stride(0),
            o_extend.stride(1),
            k_buffer.stride(0),
            k_buffer.stride(1),
            v_buffer.stride(0),
            v_buffer.stride(1),
            req_to_tokens.stride(0),
        )
        return

    _fwd_kernel[grid](
        q_extend,
        k_extend,
......@@ -226,6 +271,7 @@ def extend_attention_fwd(
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # Cache a launcher for the now-compiled kernel so later calls take the fast path above.
    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
def redundant_attention(
......
......@@ -5,6 +5,7 @@ import time
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union
import warnings
import numpy as np
import rpyc
......@@ -164,7 +165,7 @@ class ModelRpcServer(rpyc.Service):
            + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_token:
            logger.warning(
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                "KV cache pool leak detected!"
......