"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e7e3749498921121d6e710cb7524f48617cec233"
Unverified commit 37ee906f, authored by Qun Yang and committed by GitHub

Add more support for Intel Gaudi accelerators (#2357)

parent 34b364e0
+import argparse
+import dataclasses
 import sglang as sgl
+from sglang.srt.server_args import ServerArgs

-def main():
+def main(
+    server_args: ServerArgs,
+):
     # Sample prompts.
     prompts = [
         "Hello, my name is",
@@ -13,7 +19,7 @@ def main():
     sampling_params = {"temperature": 0.8, "top_p": 0.95}
     # Create an LLM.
-    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+    llm = sgl.Engine(**dataclasses.asdict(server_args))
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -25,4 +31,8 @@ def main():
 # The __main__ condition is necessary here because we use "spawn" to create subprocesses
 # Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    main(server_args)
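With the example now driven by ServerArgs, the engine picks up its model and device settings from the command line rather than a hard-coded model path. A rough sketch of the equivalent programmatic call follows; it assumes ServerArgs keeps its usual fields (e.g. model_path, device) and is illustrative rather than part of the commit:

import dataclasses

import sglang as sgl
from sglang.srt.server_args import ServerArgs

def run() -> None:
    # Construct ServerArgs directly instead of parsing CLI flags; on a
    # Gaudi host the device field would presumably be set to "hpu".
    server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    llm = sgl.Engine(**dataclasses.asdict(server_args))
    print(llm.generate(["Hello, my name is"], {"temperature": 0.8, "top_p": 0.95}))

# Keep the __main__ guard, as the example script's comment explains,
# because the engine spawns subprocesses.
if __name__ == "__main__":
    run()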
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
@@ -993,7 +993,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            torch.get_device_module(self.device).current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1055,7 +1055,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
         else:  # embedding or reward model
@@ -1130,7 +1130,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
         self.stream_output(batch.reqs)
......
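The recurring torch.cuda -> torch.get_device_module(self.device) substitution in the scheduler (and in the worker below) relies on PyTorch resolving the backend module from a device string at runtime. A minimal sketch of the idea, assuming the relevant backend is installed (torch.cuda, or torch.hpu via Habana's PyTorch bridge); sync_current_stream is a hypothetical helper, not part of the commit:

import torch

def sync_current_stream(device: str) -> None:
    # torch.get_device_module("cuda") returns torch.cuda; with Habana's
    # bridge loaded, torch.get_device_module("hpu") returns torch.hpu.
    backend = torch.get_device_module(device)
    backend.current_stream().synchronize()

if torch.cuda.is_available():
    # Equivalent to torch.cuda.current_stream().synchronize()
    sync_current_stream("cuda")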
@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,
@@ -73,7 +74,7 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )
@@ -97,7 +98,7 @@ class TpModelWorkerClient:
     def forward_thread_func(self):
         try:
-            with torch.cuda.stream(self.forward_stream):
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()
@@ -122,7 +123,7 @@ class TpModelWorkerClient:
         # Create event
         self.launch_done = threading.Event()
-        copy_done = torch.cuda.Event()
+        copy_done = torch.get_device_module(self.device).Event()
         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
@@ -190,7 +191,7 @@ class TpModelWorkerClient:
         )
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
+        torch.get_device_module(self.device).current_stream().synchronize()
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
......
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend
 logger = logging.getLogger(__name__)
@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)
     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
......
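The new guard in BaseTokenToKVPool.free() simply makes freeing an empty index tensor a no-op, so callers skip the concat (and the .cpu() transfer it would trigger) when nothing was actually allocated. A stand-alone sketch of the same pattern; TinyPool is a toy stand-in, not the real pool class:

import torch

class TinyPool:
    """Toy stand-in for a KV-cache pool that tracks free slots on the CPU."""

    def __init__(self, size: int):
        self.free_slots = torch.arange(size, dtype=torch.int64)

    def free(self, free_index: torch.Tensor) -> None:
        if free_index.numel() == 0:
            # Nothing to release: skip the concat and the .cpu() transfer.
            return
        self.free_slots = torch.concat((self.free_slots, free_index.cpu()))

pool = TinyPool(4)
pool.free(torch.empty(0, dtype=torch.int64))    # no-op with the new guard
pool.free(torch.tensor([7], dtype=torch.int64))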
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs
-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
......
@@ -25,6 +25,7 @@ import torch
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
+    get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
@@ -158,6 +159,8 @@ class ServerArgs:
            gpu_mem = get_amdgpu_memory_capacity()
        elif torch.cuda.is_available():
            gpu_mem = get_nvgpu_memory_capacity()
+       elif self.device == "hpu":
+           gpu_mem = get_hpu_memory_capacity()
        else:
            # GPU memory is not known yet or no GPU is available.
            gpu_mem = None
@@ -194,6 +197,10 @@ class ServerArgs:
            self.cuda_graph_max_bs = 160
        # Choose kernel backends
+       if self.device == "hpu":
+           self.attention_backend = "torch_native"
+           self.sampling_backend = "pytorch"
+
        if self.attention_backend is None:
            self.attention_backend = (
                "flashinfer" if is_flashinfer_available() else "triton"
......
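With these defaults, selecting device "hpu" also pins the kernel backends to the portable torch-native attention and PyTorch sampling paths. A hedged sketch of the resulting behavior, assuming ServerArgs resolves its defaults on construction and that this runs on a Gaudi host:

from sglang.srt.server_args import ServerArgs

# Constructing the dataclass runs the same default resolution as the CLI
# path (assumed to execute on a Gaudi host with the Habana stack installed).
args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device="hpu",
)

# Per this change, HPU falls back to the torch-native backends:
assert args.attention_backend == "torch_native"
assert args.sampling_backend == "pytorch"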
@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory
+    elif device == "hpu":
+        num_gpus = torch.hpu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.hpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
+                "which may cause useless memory allocation for torch HPU context.",
+            )
+
+        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
         )
+def get_hpu_memory_capacity():
+    try:
+        # Run hl-smi and capture the output
+        result = subprocess.run(
+            ["hl-smi --query | grep 'Total'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "hl-smi not found. Ensure Habana drivers are installed and accessible."
+        )
+
 # Copy from pytorch and OpenRLHF to allow creating multiple main groups.
 # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
 # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
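get_hpu_memory_capacity() assumes each matching hl-smi line ends with a memory figure followed by a "MiB" unit, which is what the split(" ")[-2] expression picks out. A small illustration of that parsing step against a fabricated output line (the real hl-smi formatting may differ across driver versions):

# Fabricated example of one line from `hl-smi --query | grep 'Total'`.
sample_line = "Total Memory                        : 32768 MiB"

# Same expression as in get_hpu_memory_capacity(): the token just before
# the trailing "MiB" unit is the capacity in MiB.
value_mib = float(sample_line.split(" ")[-2])
assert value_mib == 32768.0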
@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return major, minor
+def get_compiler_backend() -> str:
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return "hpu_backend"
+
+    return "inductor"
+
 sglang_lib = Library("sglang", "FRAGMENT")  # noqa
......
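get_compiler_backend() is what lets the @torch.compile call sites above pick Habana's compile bridge ("hpu_backend") when an HPU is present and fall back to Inductor everywhere else. A minimal usage sketch mirroring the changed decorators; the decorated function here is only a placeholder, not code from the commit:

import torch

from sglang.srt.utils import get_compiler_backend

# Resolves to "inductor" on CUDA/CPU machines, and to "hpu_backend" when
# the Habana PyTorch bridge is loaded and an HPU device is visible.
@torch.compile(dynamic=True, backend=get_compiler_backend())
def scale_and_shift(x: torch.Tensor, scale: float, shift: float) -> torch.Tensor:
    return x * scale + shift

print(scale_and_shift(torch.arange(4, dtype=torch.float32), 2.0, 1.0))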