"docs/source/vscode:/vscode.git/clone" did not exist on "35e25914ad6d9c99ef7ee1396ba00684f4015aa7"
Unverified Commit 37ee906f authored by Qun Yang's avatar Qun Yang Committed by GitHub
Browse files

Add more support for intel Gaudi accelerators (#2357)

parent 34b364e0
import argparse
import dataclasses
import sglang as sgl
from sglang.srt.server_args import ServerArgs
def main():
def main(
server_args: ServerArgs,
):
# Sample prompts.
prompts = [
"Hello, my name is",
......@@ -13,7 +19,7 @@ def main():
sampling_params = {"temperature": 0.8, "top_p": 0.95}
# Create an LLM.
llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm = sgl.Engine(**dataclasses.asdict(server_args))
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
......@@ -25,4 +31,8 @@ def main():
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
main(server_args)
......@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
sampled_index = torch.multinomial(probs_sort, num_samples=1)
# int32 range is enough to represent the token ids
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
return batch_next_token_ids
......@@ -993,7 +993,7 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
......@@ -1055,7 +1055,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
......@@ -1130,7 +1130,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)
......
......@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
......@@ -73,7 +74,7 @@ class TpModelWorkerClient:
# Launch threads
self.input_queue = Queue()
self.output_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_stream = torch.get_device_module(self.device).Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
......@@ -97,7 +98,7 @@ class TpModelWorkerClient:
def forward_thread_func(self):
try:
with torch.cuda.stream(self.forward_stream):
with torch.get_device_module(self.device).stream(self.forward_stream):
self.forward_thread_func_()
except Exception:
traceback = get_exception_traceback()
......@@ -122,7 +123,7 @@ class TpModelWorkerClient:
# Create event
self.launch_done = threading.Event()
copy_done = torch.cuda.Event()
copy_done = torch.get_device_module(self.device).Event()
# Resolve future tokens in the input
input_ids = model_worker_batch.input_ids
......@@ -190,7 +191,7 @@ class TpModelWorkerClient:
)
# A cuda stream sync here to avoid the cuda illegal memory access error.
torch.cuda.current_stream().synchronize()
torch.get_device_module(self.device).current_stream().synchronize()
# Push a new batch to the queue
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
......
......@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
import torch
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_compiler_backend
logger = logging.getLogger(__name__)
......@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
return select_index.to(self.device, non_blocking=True)
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
else:
......@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_1[loc] = src_1.to(dtype).view(store_dtype)
dst_2[loc] = src_2.to(dtype).view(store_dtype)
......
......@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import get_compiler_backend, set_weight_attrs
@torch.compile
@torch.compile(backend=get_compiler_backend())
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
......
......@@ -25,6 +25,7 @@ import torch
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
get_nvgpu_memory_capacity,
is_flashinfer_available,
is_hip,
......@@ -158,6 +159,8 @@ class ServerArgs:
gpu_mem = get_amdgpu_memory_capacity()
elif torch.cuda.is_available():
gpu_mem = get_nvgpu_memory_capacity()
elif self.device == "hpu":
gpu_mem = get_hpu_memory_capacity()
else:
# GPU memory is not known yet or no GPU is available.
gpu_mem = None
......@@ -194,6 +197,10 @@ class ServerArgs:
self.cuda_graph_max_bs = 160
# Choose kernel backends
if self.device == "hpu":
self.attention_backend = "torch_native"
self.sampling_backend = "pytorch"
if self.attention_backend is None:
self.attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
......
......@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
free_gpu_memory = total_gpu_memory - used_memory
elif device == "hpu":
num_gpus = torch.hpu.device_count()
assert gpu_id < num_gpus
if torch.hpu.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
"which may cause useless memory allocation for torch HPU context.",
)
free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device(device, gpu_id)
......@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
)
def get_hpu_memory_capacity():
try:
# Run hl-smi and capture the output
result = subprocess.run(
["hl-smi --query | grep 'Total'"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
# Parse the output to extract memory values in MiB
memory_values = [
float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
]
if not memory_values:
raise ValueError("No GPU memory values found.")
# Return the minimum memory value
return min(memory_values)
except FileNotFoundError:
raise RuntimeError(
"hl-smi not found. Ensure Habana drivers are installed and accessible."
)
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
......@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
return major, minor
def get_compiler_backend() -> str:
if hasattr(torch, "hpu") and torch.hpu.is_available():
return "hpu_backend"
return "inductor"
sglang_lib = Library("sglang", "FRAGMENT") # noqa
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment