Unverified commit 37ee906f authored by Qun Yang, committed by GitHub

Add more support for Intel Gaudi accelerators (#2357)

parent 34b364e0
import argparse
import dataclasses
import sglang as sgl
from sglang.srt.server_args import ServerArgs
def main():
def main(
server_args: ServerArgs,
):
# Sample prompts.
prompts = [
"Hello, my name is",
@@ -13,7 +19,7 @@ def main():
sampling_params = {"temperature": 0.8, "top_p": 0.95}
# Create an LLM.
llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm = sgl.Engine(**dataclasses.asdict(server_args))
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
@@ -25,4 +31,8 @@ def main():
# The __main__ guard is necessary here because we use "spawn" to create subprocesses.
# "spawn" starts a fresh program every time; without the guard, sgl.Engine would keep spawning new processes in an infinite loop.
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
main(server_args)
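With the new signature, the offline example can also be driven programmatically rather than from the CLI. A minimal sketch, assuming ServerArgs exposes model_path and device fields (only the asdict/Engine plumbing is taken from the diff):

import dataclasses

import sglang as sgl
from sglang.srt.server_args import ServerArgs

if __name__ == "__main__":
    # Build ServerArgs directly instead of parsing CLI flags; the field names
    # used here (model_path, device) are assumptions about the dataclass.
    server_args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        device="hpu",  # e.g. "cuda" on NVIDIA GPUs; HPU is what this PR targets
    )
    llm = sgl.Engine(**dataclasses.asdict(server_args))
    outputs = llm.generate(["Hello, my name is"], {"temperature": 0.8, "top_p": 0.95})
    print(outputs)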
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
sampled_index = torch.multinomial(probs_sort, num_samples=1)
# int32 range is enough to represent the token ids
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
return batch_next_token_ids
@@ -993,7 +993,7 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1055,7 +1055,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
@@ -1130,7 +1130,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)
......
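The scheduler hunks above replace the hard-coded torch.cuda stream sync with a lookup through the device module. A minimal sketch of that indirection, assuming a PyTorch build that provides torch.get_device_module():

import torch

def sync_current_stream(device: str) -> None:
    # torch.get_device_module("cuda") returns torch.cuda, "hpu" returns
    # torch.hpu, "xpu" returns torch.xpu, so the same call works on every
    # accelerator the scheduler may run on.
    torch.get_device_module(device).current_stream().synchronize()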
@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
@@ -73,7 +74,7 @@ class TpModelWorkerClient:
# Launch threads
self.input_queue = Queue()
self.output_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_stream = torch.get_device_module(self.device).Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
@@ -97,7 +98,7 @@ class TpModelWorkerClient:
def forward_thread_func(self):
try:
with torch.cuda.stream(self.forward_stream):
with torch.get_device_module(self.device).stream(self.forward_stream):
self.forward_thread_func_()
except Exception:
traceback = get_exception_traceback()
@@ -122,7 +123,7 @@ class TpModelWorkerClient:
# Create event
self.launch_done = threading.Event()
copy_done = torch.cuda.Event()
copy_done = torch.get_device_module(self.device).Event()
# Resolve future tokens in the input
input_ids = model_worker_batch.input_ids
@@ -190,7 +191,7 @@ class TpModelWorkerClient:
)
# A cuda stream sync here to avoid the cuda illegal memory access error.
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
# Push a new batch to the queue
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
......
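The overlap worker applies the same device-module indirection to stream and event objects. A short sketch of the pattern, with "cuda" used only as an example value for self.device:

import torch

device_module = torch.get_device_module("cuda")
forward_stream = device_module.Stream()  # was torch.cuda.Stream()
copy_done = device_module.Event()        # was torch.cuda.Event()

with device_module.stream(forward_stream):
    pass  # launch the forward pass on the side stream, then record copy_done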
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_compiler_backend
logger = logging.getLogger(__name__)
@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
return select_index.to(self.device, non_blocking=True)
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_1[loc] = src_1.to(dtype).view(store_dtype)
dst_2[loc] = src_2.to(dtype).view(store_dtype)
......
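The backend argument added to torch.compile above is resolved at import time by the get_compiler_backend() helper introduced later in this PR. A small usage sketch, where the decorated function is a hypothetical placeholder:

import torch

def get_compiler_backend() -> str:
    # Mirrors the helper added in sglang.srt.utils: pick the Habana graph
    # compiler when an HPU is present, otherwise fall back to inductor.
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return "hpu_backend"
    return "inductor"

@torch.compile(dynamic=True, backend=get_compiler_backend())
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0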
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import get_compiler_backend, set_weight_attrs
@torch.compile
@torch.compile(backend=get_compiler_backend())
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
......
@@ -25,6 +25,7 @@ import torch
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
get_nvgpu_memory_capacity,
is_flashinfer_available,
is_hip,
@@ -158,6 +159,8 @@ class ServerArgs:
gpu_mem = get_amdgpu_memory_capacity()
elif torch.cuda.is_available():
gpu_mem = get_nvgpu_memory_capacity()
elif self.device == "hpu":
gpu_mem = get_hpu_memory_capacity()
else:
# GPU memory is not known yet or no GPU is available.
gpu_mem = None
@@ -194,6 +197,10 @@ class ServerArgs:
self.cuda_graph_max_bs = 160
# Choose kernel backends
if self.device == "hpu":
self.attention_backend = "torch_native"
self.sampling_backend = "pytorch"
if self.attention_backend is None:
self.attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
......
@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
free_gpu_memory = total_gpu_memory - used_memory
elif device == "hpu":
num_gpus = torch.hpu.device_count()
assert gpu_id < num_gpus
if torch.hpu.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
"which may cause useless memory allocation for torch HPU context.",
)
free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device(device, gpu_id)
@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
)
def get_hpu_memory_capacity():
try:
# Run hl-smi and capture the output
result = subprocess.run(
["hl-smi --query | grep 'Total'"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
# Parse the output to extract memory values in MiB
memory_values = [
float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
]
if not memory_values:
raise ValueError("No GPU memory values found.")
# Return the minimum memory value
return min(memory_values)
except FileNotFoundError:
raise RuntimeError(
"hl-smi not found. Ensure Habana drivers are installed and accessible."
)
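get_hpu_memory_capacity() depends on the textual layout of hl-smi output. The sample line below is an assumption for illustration only; the split(" ")[-2] convention is what the code above relies on:

# Hypothetical `hl-smi --query | grep 'Total'` output for a two-card node.
sample_stdout = (
    "Total                          : 98304 MiB\n"
    "Total                          : 98304 MiB"
)
memory_values = [
    float(line.split(" ")[-2]) for line in sample_stdout.strip().split("\n")
]
print(min(memory_values))  # -> 98304.0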
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
return major, minor
def get_compiler_backend() -> str:
if hasattr(torch, "hpu") and torch.hpu.is_available():
return "hpu_backend"
return "inductor"
sglang_lib = Library("sglang", "FRAGMENT") # noqa
......