Commit 5d638c92 (unverified)

[Feature, Hardware] Enable SGLang on XPU GPUs via PyTorch (#1480)

Authored Oct 13, 2024 by Zhang, Liangang; committed via GitHub on Oct 12, 2024
Parent: e37cdab0

Showing 8 changed files with 55 additions and 19 deletions (+55, -19)
python/pyproject.toml                                                +13  -4
python/sglang/bench_latency.py                                       +21  -6
python/sglang/srt/layers/attention/triton_backend.py                  +4  -2
python/sglang/srt/layers/attention/triton_ops/extend_attention.py     +5  -3
python/sglang/srt/layers/attention/triton_ops/prefill_attention.py    +4  -2
python/sglang/srt/model_executor/forward_batch_info.py                +1  -1
python/sglang/srt/model_executor/model_runner.py                      +6  -0
python/sglang/srt/server_args.py                                      +1  -1
python/pyproject.toml

@@ -20,16 +20,25 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
+runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
     "packaging", "pillow", "psutil", "pydantic", "python-multipart",
-    "torch", "torchao", "uvicorn", "uvloop", "zmq",
-    "vllm==0.5.5", "outlines>=0.0.44", "modelscope"]
+    "torchao", "uvicorn", "uvloop", "zmq",
+    "outlines>=0.0.44", "modelscope"]
+torch = ["torch"]
+# xpu is not enabled in public vllm and torch whl,
+# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html install vllm
+vllm = ["vllm==0.5.5"]
+srt = ["sglang[runtime_common]", "torch", "vllm"]
+srt_xpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
 test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 dev = ["sglang[all]", "sglang[test]"]
+dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
 
 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"
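The new extras split the old monolithic srt dependency list: runtime_common holds the device-agnostic packages, while srt adds torch and vllm (CUDA wheels) and srt_xpu deliberately leaves them out, since XPU builds of both must be installed manually per the linked vLLM guide. Below is a minimal sketch of the runtime check this split anticipates; the detect_device helper is illustrative and not part of the commit.

    import torch

    def detect_device() -> str:
        # Hypothetical helper: pick whichever accelerator the installed wheels support.
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
        return "cpu"

    print(detect_device())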
python/sglang/bench_latency.py

@@ -288,8 +288,15 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")
 
 
+def synchronize(device):
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+
+
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:

@@ -312,10 +319,10 @@ def latency_test_run_once(
     tot_latency = 0
 
     # Prefill
-    torch.cuda.synchronize()
+    synchronize(device)
     tic = time.time()
     next_token_ids, _, batch = extend(reqs, model_runner)
-    torch.cuda.synchronize()
+    synchronize(device)
     prefill_latency = time.time() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency

@@ -328,10 +335,10 @@ def latency_test_run_once(
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
-        torch.cuda.synchronize()
+        synchronize(device)
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-        torch.cuda.synchronize()
+        synchronize(device)
         latency = time.time() - tic
         tot_latency += latency
         throughput = batch_size / latency

@@ -387,6 +394,7 @@ def latency_test(
         bench_args.batch_size[0],
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
+        server_args.device,
     )
 
     rank_print("Benchmark ...")

@@ -397,7 +405,14 @@ def latency_test(
     ):
         reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
         ret = latency_test_run_once(
-            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
         )
         if ret is not None:
             result_list.append(ret)
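The benchmark now threads the device string into latency_test_run_once and replaces every hard-coded torch.cuda.synchronize() with the new synchronize(device) helper, so prefill and decode timings stay correct on XPU. A small sketch of the same timing pattern; the time_step wrapper is hypothetical and only illustrates why each phase is bracketed by a synchronize call.

    import time
    import torch

    def synchronize(device):
        # Mirrors the helper added in bench_latency.py.
        if device == "cuda":
            torch.cuda.synchronize()
        elif device == "xpu":
            torch.xpu.synchronize()

    def time_step(fn, device):
        # Hypothetical wrapper: drain the device queue before and after fn so
        # the wall-clock delta reflects kernel execution time, not launch time.
        synchronize(device)
        tic = time.time()
        result = fn()
        synchronize(device)
        return result, time.time() - tic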
python/sglang/srt/layers/attention/triton_backend.py

@@ -40,6 +40,8 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
 
+        self.device = model_runner.device
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

@@ -51,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
         attn_logits = torch.empty(
             (self.num_head, total_num_tokens),
             dtype=self.reduce_dtype,
-            device="cuda",
+            device=self.device,
         )
 
         max_seq_len = torch.max(forward_batch.seq_lens).item()

@@ -67,7 +69,7 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
 
         self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
+            (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
             (
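The backend caches model_runner.device at construction time and uses it for every buffer allocation instead of the literal "cuda". A minimal sketch of that pattern follows; DeviceAwareBuffers is an illustrative class, not the real TritonAttnBackend.

    import torch

    class DeviceAwareBuffers:
        # Illustrative only: store the device once, reuse it for all allocations.
        def __init__(self, device, num_head, reduce_dtype=torch.float32):
            self.device = device              # e.g. "cuda" or "xpu"
            self.num_head = num_head
            self.reduce_dtype = reduce_dtype

        def alloc_attn_logits(self, total_num_tokens):
            # Same call shape as the attn_logits allocation in the diff above.
            return torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device=self.device,
            )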
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit

@@ -286,12 +288,12 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if CUDA_CAPABILITY[0] >= 9:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
             BLOCK_M, BLOCK_N = (128, 64)
         else:
             BLOCK_M, BLOCK_N = (32, 64)
-    elif CUDA_CAPABILITY[0] >= 8:
+    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         if Lq <= 128:
             BLOCK_M, BLOCK_N = (128, 128)
         elif Lq <= 256:
python/sglang/srt/layers/attention/triton_ops/prefill_attention.py

@@ -24,7 +24,9 @@ import torch
 import triton
 import triton.language as tl
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit

@@ -145,7 +147,7 @@ def _fwd_kernel(
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    if CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
         BLOCK = 64
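Both Triton kernel files above apply the same guard: torch.cuda.get_device_capability() is only queried when CUDA is actually available, so importing these modules no longer fails on XPU-only builds, and every capability comparison is prefixed with is_cuda_available (falling through to the existing conservative branch otherwise). A self-contained sketch of the pattern, modeled on the context_attention_fwd change:

    import torch

    # Guard the capability query so module import works without a CUDA device.
    is_cuda_available = torch.cuda.is_available()
    if is_cuda_available:
        CUDA_CAPABILITY = torch.cuda.get_device_capability()

    def pick_block():
        # On capability >= 8 use the larger tile; on anything else, including
        # non-CUDA devices such as XPU, keep the conservative tile.
        if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
            return 128
        return 64

    print(pick_block())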
python/sglang/srt/model_executor/forward_batch_info.py

@@ -118,7 +118,7 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-        device = "cuda"
+        device = model_runner.device
 
         ret = cls(
             forward_mode=batch.forward_mode,
python/sglang/srt/model_executor/model_runner.py

@@ -138,6 +138,7 @@ class ModelRunner:
             self.init_attention_backend()
             self.init_cuda_graphs()
         else:
+            self.cuda_graph_runner = None
             self.init_attention_backend()
 
     def init_torch_distributed(self):

@@ -146,6 +147,11 @@ class ModelRunner:
         if self.device == "cuda":
             torch.cuda.set_device(self.gpu_id)
             backend = "nccl"
+        # ToDO(liangan1):Just use gloo to bypass the initilization fail
+        # Need to use xccl for xpu backend in the future
+        elif self.device == "xpu":
+            torch.xpu.set_device(self.gpu_id)
+            backend = "gloo"
 
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
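ModelRunner now picks the process-group backend per device: NCCL for CUDA, and gloo as a stopgap for XPU (the in-code TODO points to xccl as the eventual backend). A hedged sketch of that selection wired into torch.distributed; the init_method address and the rank/world-size handling are placeholders, not taken from the commit.

    import torch
    import torch.distributed as dist

    def init_distributed(device, gpu_id, rank=0, world_size=1):
        if device == "cuda":
            torch.cuda.set_device(gpu_id)
            backend = "nccl"
        elif device == "xpu":
            torch.xpu.set_device(gpu_id)
            backend = "gloo"   # stopgap until an xccl backend is available
        else:
            backend = "gloo"
        dist.init_process_group(
            backend=backend,
            init_method="tcp://127.0.0.1:29500",  # placeholder address
            rank=rank,
            world_size=world_size,
        )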
python/sglang/srt/server_args.py

@@ -242,7 +242,7 @@ class ServerArgs:
             "--device",
             type=str,
             default="cuda",
-            choices=["cuda"],
+            choices=["cuda", "xpu"],
             help="The device type.",
         )
         parser.add_argument(
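With "xpu" added to the accepted --device values, the flag can be exercised directly. A self-contained argparse reproduction: the flag definition matches server_args.py, while the standalone parser and the sample arguments are illustrative.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "xpu"],
        help="The device type.",
    )

    args = parser.parse_args(["--device", "xpu"])
    assert args.device == "xpu"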