Unverified commit 5d638c92, authored by Zhang, Liangang and committed by GitHub

[Feature, Hardware] Enable SGLang on XPU GPUs via PyTorch (#1480)

parent e37cdab0
@@ -20,16 +20,25 @@ dependencies = [
 ]

 [project.optional-dependencies]
-srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
-       "packaging", "pillow", "psutil", "pydantic", "python-multipart",
-       "torch", "torchao", "uvicorn", "uvloop", "zmq",
-       "vllm==0.5.5", "outlines>=0.0.44", "modelscope"]
+runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
+                  "packaging", "pillow", "psutil", "pydantic", "python-multipart",
+                  "torchao", "uvicorn", "uvloop", "zmq",
+                  "outlines>=0.0.44", "modelscope"]
+torch = ["torch"]
+# xpu is not enabled in the public vllm and torch wheels;
+# follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
+vllm = ["vllm==0.5.5"]
+srt = ["sglang[runtime_common]", "torch", "vllm"]
+srt_xpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
 test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate", "peft"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 dev = ["sglang[all]", "sglang[test]"]
+dev_xpu = ["sglang[all_xpu]", "sglang[test]"]

 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"
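The extras are split so that the shared runtime dependencies live in runtime_common, the CUDA path (srt, all, dev) layers torch and vllm on top of it, and the XPU path (srt_xpu, all_xpu, dev_xpu) leaves both out, since they must be installed separately by following the vLLM XPU guide linked in the comment. As a rough illustration of how a caller might pick the matching device string at runtime, here is a small hypothetical helper that is not part of this PR:

import torch

def pick_device() -> str:
    """Return the --device value matching the local accelerator."""
    if torch.cuda.is_available():
        return "cuda"  # covered by the sglang[srt] / sglang[all] extras
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"   # covered by the new sglang[srt_xpu] / sglang[all_xpu] extras
    raise RuntimeError("no supported accelerator found")

print(pick_device())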
@@ -288,8 +288,15 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")


+def synchronize(device):
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+
+
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -312,10 +319,10 @@ def latency_test_run_once(
     tot_latency = 0

     # Prefill
-    torch.cuda.synchronize()
+    synchronize(device)
     tic = time.time()
     next_token_ids, _, batch = extend(reqs, model_runner)
-    torch.cuda.synchronize()
+    synchronize(device)
     prefill_latency = time.time() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
@@ -328,10 +335,10 @@ def latency_test_run_once(
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
-        torch.cuda.synchronize()
+        synchronize(device)
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-        torch.cuda.synchronize()
+        synchronize(device)
         latency = time.time() - tic
         tot_latency += latency
         throughput = batch_size / latency
@@ -387,6 +394,7 @@ def latency_test(
         bench_args.batch_size[0],
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
+        server_args.device,
     )

     rank_print("Benchmark ...")
@@ -397,7 +405,14 @@ def latency_test(
     ):
         reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
         ret = latency_test_run_once(
-            bench_args.run_name, model_runner, rank_print, reqs, bs, il, ol
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
         )
         if ret is not None:
             result_list.append(ret)
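The new synchronize(device) helper keeps the prefill and decode timings accurate on both backends by waiting for the active device before reading the clock. A sketch of an equivalent, more generic dispatch (illustrative only, not the code in this PR) that only requires a torch.<device>.synchronize() to exist:

import torch

def synchronize(device: str) -> None:
    # Dispatch to torch.<device>.synchronize() when that submodule provides one.
    mod = getattr(torch, device, None)
    if mod is not None and hasattr(mod, "synchronize"):
        mod.synchronize()

if torch.cuda.is_available():
    synchronize("cuda")

The explicit if/elif form used in the PR is arguably clearer for two supported devices; the getattr form only pays off once more backends are added.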
@@ -40,6 +40,8 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

+        self.device = model_runner.device
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
@@ -51,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
-                device="cuda",
+                device=self.device,
             )

             max_seq_len = torch.max(forward_batch.seq_lens).item()
@@ -67,7 +69,7 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
         self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
+            (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
             (
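Storing model_runner.device on the attention backend lets the scratch buffers follow whatever device the runner was created with instead of hard-coding "cuda". A self-contained sketch of the allocation pattern (function name and defaults are illustrative):

import torch

def make_attn_logits(
    num_head: int, total_num_tokens: int, device: str, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    # The same call works for "cuda", "xpu", or "cpu" device strings.
    return torch.empty((num_head, total_num_tokens), dtype=dtype, device=device)

print(make_attn_logits(8, 16, "cpu").shape)  # torch.Size([8, 16])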
@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )

-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()


 @triton.jit
@@ -286,12 +288,12 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)

-    if CUDA_CAPABILITY[0] >= 9:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
             BLOCK_M, BLOCK_N = (128, 64)
         else:
             BLOCK_M, BLOCK_N = (32, 64)
-    elif CUDA_CAPABILITY[0] >= 8:
+    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         if Lq <= 128:
             BLOCK_M, BLOCK_N = (128, 128)
         elif Lq <= 256:
@@ -24,7 +24,9 @@ import torch
 import triton
 import triton.language as tl

-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()


 @triton.jit
@@ -145,7 +147,7 @@ def _fwd_kernel(
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    if CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
         BLOCK = 64
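Both Triton kernel modules previously called torch.cuda.get_device_capability() at import time, which raises when no CUDA device is present; guarding the query and short-circuiting every capability check keeps the modules importable on XPU hosts. The pattern in isolation, mirroring the 128/64 block choice in context_attention_fwd (assembled here as a standalone example, not the kernel code itself):

import torch

is_cuda_available = torch.cuda.is_available()
CUDA_CAPABILITY = torch.cuda.get_device_capability() if is_cuda_available else None

def pick_block() -> int:
    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
        return 128
    return 64  # conservative default when CUDA is absent or older

print(pick_block())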
@@ -118,7 +118,7 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-        device = "cuda"
+        device = model_runner.device

         ret = cls(
             forward_mode=batch.forward_mode,
@@ -138,6 +138,7 @@ class ModelRunner:
             self.init_attention_backend()
             self.init_cuda_graphs()
         else:
+            self.cuda_graph_runner = None
             self.init_attention_backend()

     def init_torch_distributed(self):
@@ -146,6 +147,11 @@ class ModelRunner:
         if self.device == "cuda":
             torch.cuda.set_device(self.gpu_id)
             backend = "nccl"
+        # TODO(liangan1): just use gloo for now to bypass the initialization failure;
+        # the xpu backend should switch to xccl in the future.
+        elif self.device == "xpu":
+            torch.xpu.set_device(self.gpu_id)
+            backend = "gloo"

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
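The gloo fallback keeps distributed initialization working on XPU until an xccl backend is available, as the TODO notes. A standalone sketch of just the backend selection (helper name and arguments are assumptions; the real init_torch_distributed does more than this):

import torch
import torch.distributed as dist

def init_distributed(device: str, gpu_id: int, init_method: str, world_size: int, rank: int) -> None:
    if device == "cuda":
        torch.cuda.set_device(gpu_id)
        backend = "nccl"
    elif device == "xpu":
        torch.xpu.set_device(gpu_id)
        backend = "gloo"  # temporary workaround until xccl is usable
    else:
        backend = "gloo"
    dist.init_process_group(
        backend=backend, init_method=init_method, world_size=world_size, rank=rank
    )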
@@ -242,7 +242,7 @@ class ServerArgs:
             "--device",
             type=str,
             default="cuda",
-            choices=["cuda"],
+            choices=["cuda", "xpu"],
             help="The device type.",
         )
         parser.add_argument(
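With "xpu" added to the accepted --device values, an Intel GPU can be selected from the command line. A minimal argparse sketch reproducing just this flag:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--device", type=str, default="cuda", choices=["cuda", "xpu"], help="The device type."
)
print(parser.parse_args(["--device", "xpu"]).device)  # -> xpu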