sglang · Commit ffe4aaee (unverified)

Fix for T4 GPUs (#16)

Authored by Ying Sheng on Jan 16, 2024; committed by GitHub on Jan 16, 2024
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Parent: 5b27a1dc

Showing 6 changed files with 68 additions and 6 deletions:
- README.md (+9, -1)
- python/pyproject.toml (+1, -1)
- python/sglang/__init__.py (+1, -1)
- python/sglang/srt/layers/context_flashattention_nopad.py (+8, -1)
- python/sglang/srt/layers/extend_attention.py (+47, -1)
- python/sglang/srt/managers/router/model_rpc.py (+2, -1)
README.md

````diff
@@ -32,6 +32,10 @@ pip install --upgrade pip
 pip install -e "python[all]"
 ```
 
+### Notes
+- If you are using older GPUs (NVIDIA T4, V100), please use `pip install "triton>=2.2.0"` to avoid some bugs in the triton compiler.
+- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install sglang[openai]`.
+
 ## Quick Start
 
 The example below shows how to use sglang to answer a multi-turn question.
````
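The T4/V100 note above hinges on the GPU's compute capability. A quick way to check what your card reports, assuming a CUDA-enabled PyTorch install with a visible GPU (this snippet is illustrative and not part of the commit):

```python
import torch

# (major, minor) compute capability: a T4 reports (7, 5), a V100 (7, 0);
# Ampere and newer report major >= 8.
major, minor = torch.cuda.get_device_capability()
if major < 8:
    print('Pre-Ampere GPU detected; consider: pip install "triton>=2.2.0"')
```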
````diff
@@ -197,7 +201,7 @@ for out in state.text_iter():
 ## Backend: SGLang Runtime (SRT)
 The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
 However, it can also be used as a standalone API server.
-In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases.
+In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.
 
 ### Usage
 Launch a server
@@ -221,6 +225,10 @@ curl http://localhost:30000/v1/completions \
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
 ```
+- If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
+```
 
 ### Supported Models
 - Llama
````
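`--mem-fraction-static` caps how much GPU memory the server sets aside statically, which the KV cache pool shares. A rough way to see what different fractions correspond to on your card, assuming a CUDA-enabled PyTorch install (illustrative only; this is not how sglang sizes its pool internally):

```python
import torch

# free/total device memory in bytes; OOM at the default 0.9 suggests
# trying a smaller fraction such as 0.7.
free, total = torch.cuda.mem_get_info()
for fraction in (0.9, 0.8, 0.7):
    print(f"--mem-fraction-static {fraction}: "
          f"~{fraction * total / 2**30:.1f} GiB reserved")
```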
python/pyproject.toml

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sglang"
-version = "0.1.3"
+version = "0.1.4"
 description = "A structured generation langauge for LLMs."
 readme = "README.md"
 requires-python = ">=3.8"
```
python/sglang/__init__.py

```diff
-__version__ = "0.1.3"
+__version__ = "0.1.4"
 
 from sglang.api import *
 from sglang.global_config import global_config
```
python/sglang/srt/layers/context_flashattention_nopad.py

```diff
@@ -6,6 +6,9 @@ import triton.language as tl
 from sglang.srt.utils import wrap_kernel_launcher
 
+CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+
 @triton.jit
 def _fwd_kernel(
     Q,
@@ -120,7 +123,11 @@ cached_kernel = None
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    BLOCK = 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK = 128
+    else:
+        BLOCK = 64
+
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
     assert Lk in {16, 32, 64, 128}
```
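The pattern here: `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, and the kernel drops to a smaller Triton block on pre-Ampere GPUs, likely because those parts (T4, V100) have less shared memory per SM. A minimal standalone illustration of the same gate, not part of the commit:

```python
import torch

# Pick a Triton block size by compute capability, as the diff above does:
# Ampere+ (major >= 8) can handle 128-wide blocks, while Turing/Volta
# (T4, V100) gets the smaller 64-wide block.
CUDA_CAPABILITY = torch.cuda.get_device_capability()
BLOCK = 128 if CUDA_CAPABILITY[0] >= 8 else 64
print(f"compute capability {CUDA_CAPABILITY}, BLOCK={BLOCK}")
```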
python/sglang/srt/layers/extend_attention.py

```diff
@@ -2,6 +2,10 @@ import torch
 import triton
 import triton.language as tl
 
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.utils import wrap_kernel_launcher
+
+CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
 
 @triton.jit
@@ -153,6 +157,9 @@ def _fwd_kernel(
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
 
 
+cached_kernel = None
+
+
 def extend_attention_fwd(
     q_extend,
     k_extend,
@@ -175,7 +182,11 @@ def extend_attention_fwd(
     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    BLOCK_M, BLOCK_N = 128, 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = 128, 128
+    else:
+        BLOCK_M, BLOCK_N = 64, 64
+
     Lq, Lk, Lv, Lo = (
         q_extend.shape[-1],
         k_extend.shape[-1],
@@ -193,6 +204,40 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
+    global cached_kernel
+    if cached_kernel:
+        cached_kernel(
+            grid,
+            num_warps,
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend,
+            k_buffer,
+            v_buffer,
+            req_to_tokens,
+            b_req_idx,
+            b_seq_len,
+            b_start_loc_extend,
+            b_seq_len_extend,
+            sm_scale,
+            kv_group_num,
+            q_extend.stride(0),
+            q_extend.stride(1),
+            k_extend.stride(0),
+            k_extend.stride(1),
+            v_extend.stride(0),
+            v_extend.stride(1),
+            o_extend.stride(0),
+            o_extend.stride(1),
+            k_buffer.stride(0),
+            k_buffer.stride(1),
+            v_buffer.stride(0),
+            v_buffer.stride(1),
+            req_to_tokens.stride(0),
+        )
+        return
+
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -226,6 +271,7 @@ def extend_attention_fwd(
         num_warps=num_warps,
         num_stages=num_stages,
     )
+    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
 
 
 def redundant_attention(
```
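The hot path this hunk adds: after the first launch, `cached_kernel = wrap_kernel_launcher(_fwd_kernel)` stores a callable that re-launches the already-compiled kernel directly, so later calls skip Triton's Python-side dispatch. A minimal sketch of the same memoization idea, with hypothetical names (not sglang's actual `wrap_kernel_launcher` implementation):

```python
from typing import Any, Callable, Optional

# Hypothetical stand-in for the pattern above: cache a cheap launcher
# after the first (expensive) call, then reuse it on every later call.
_cached_launcher: Optional[Callable[..., Any]] = None

def launch_attention(kernel: Callable[..., Any], *args: Any) -> None:
    global _cached_launcher
    if _cached_launcher is not None:
        _cached_launcher(*args)   # fast path: no dispatch overhead
        return
    kernel(*args)                 # slow path: first call compiles/dispatches
    _cached_launcher = kernel     # in sglang: wrap_kernel_launcher(_fwd_kernel)
```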
python/sglang/srt/managers/router/model_rpc.py

```diff
@@ -5,6 +5,7 @@ import time
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum, auto
 from typing import Dict, List, Optional, Tuple, Union
+import warnings
 
 import numpy as np
 import rpyc
@@ -164,7 +165,7 @@ class ModelRpcServer(rpyc.Service):
             + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_token:
-            logger.warning(
+            warnings.warn(
                 "Warning: "
                 f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                 "KV cache pool leak detected!"
```
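The check itself is an invariant on the token pool: every KV slot should be either free or evictable from the tree cache, so the two should always sum back to `max_total_num_token`. A simplified standalone version of the check (names follow the diff; the surrounding `ModelRpcServer` context is omitted):

```python
import warnings

def check_kv_pool_integrity(free_size: int, evictable_size: int,
                            max_total_num_token: int) -> None:
    # If allocated slots were never returned to the pool, free + evictable
    # no longer covers the whole pool and a leak is reported.
    available_size = free_size + evictable_size
    if available_size != max_total_num_token:
        warnings.warn(
            "Warning: "
            f"available_size={available_size}, "
            f"max_total_num_token={max_total_num_token}\n"
            "KV cache pool leak detected!"
        )
```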