Commit 37ee906f (unverified)
Authored Dec 06, 2024 by Qun Yang; committed by GitHub on Dec 06, 2024
Parent: 34b364e0

Add more support for intel Gaudi accelerators (#2357)

Showing 8 changed files with 88 additions and 14 deletions (+88 -14)
examples/runtime/engine/offline_batch_inference.py     +13  -3
python/sglang/srt/layers/sampler.py                     +2  -0
python/sglang/srt/managers/scheduler.py                 +3  -3
python/sglang/srt/managers/tp_worker_overlap_thread.py  +6  -5
python/sglang/srt/mem_cache/memory_pool.py              +5  -1
python/sglang/srt/models/commandr.py                    +2  -2
python/sglang/srt/server_args.py                        +7  -0
python/sglang/srt/utils.py                             +50  -0
examples/runtime/engine/offline_batch_inference.py

+import argparse
+import dataclasses
+
 import sglang as sgl
+from sglang.srt.server_args import ServerArgs


-def main():
+def main(
+    server_args: ServerArgs,
+):
     # Sample prompts.
     prompts = [
         "Hello, my name is",

@@ -13,7 +19,7 @@ def main():
     sampling_params = {"temperature": 0.8, "top_p": 0.95}

     # Create an LLM.
-    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+    llm = sgl.Engine(**dataclasses.asdict(server_args))

     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.

@@ -25,4 +31,8 @@ def main():
 # The __main__ condition is necessary here because we use "spawn" to create subprocesses
 # Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    main(server_args)
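Note: the script now builds a ServerArgs dataclass from the command line and expands it into sgl.Engine with dataclasses.asdict, instead of hard-coding model_path. A minimal, self-contained sketch of that pattern follows; EngineConfig and FakeEngine are hypothetical stand-ins for ServerArgs and sgl.Engine, not sglang APIs.

```python
# Sketch of the CLI -> dataclass -> **kwargs pattern used above (illustrative names).
import argparse
import dataclasses


@dataclasses.dataclass
class EngineConfig:  # hypothetical stand-in for ServerArgs
    model_path: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device: str = "cuda"  # e.g. "hpu" on Gaudi


class FakeEngine:  # hypothetical stand-in for sgl.Engine
    def __init__(self, model_path: str, device: str):
        print(f"loading {model_path} on {device}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", default=EngineConfig.model_path)
    parser.add_argument("--device", default=EngineConfig.device)
    args = parser.parse_args()

    config = EngineConfig(**vars(args))
    # The commit's pattern: expand the config dataclass into engine kwargs.
    engine = FakeEngine(**dataclasses.asdict(config))
```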
python/sglang/srt/layers/sampler.py

@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
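Note: top_k_top_p_min_p_sampling_from_probs_torch is the PyTorch-native sampling fallback, presumably the path exercised when sampling_backend is "pytorch" (the default this commit pins for HPU); the diff itself only adds the int32 cast before the gather. The toy sketch below walks the min-p filter and sampling steps on a small batch; the threshold construction and tensor values are illustrative, not copied from sglang.

```python
# Toy walk-through of min-p filtering followed by multinomial sampling.
import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05],
                      [0.90, 0.05, 0.03, 0.02]])
min_p = torch.tensor([0.2, 0.2])  # per-request min-p

# Sort once so the threshold can be applied to the sorted distribution.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

# One plausible threshold: min_p times the largest probability in each row.
min_p_thresholds = probs_sort[:, 0] * min_p
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

# Renormalize by the row maximum (as in the diff) and sample.
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
sampled_index = torch.multinomial(probs_sort, num_samples=1)

# int32 range is enough to represent the token ids (as the diff notes).
probs_idx = probs_idx.to(torch.int32)
next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
print(next_token_ids)
```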
python/sglang/srt/managers/scheduler.py

@@ -993,7 +993,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            torch.get_device_module(self.device).current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):

@@ -1055,7 +1055,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
         else:
             # embedding or reward model

@@ -1130,7 +1130,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
             self.stream_output(batch.reqs)
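Note: all three scheduler changes follow one pattern: replace the hard-coded torch.cuda.current_stream().synchronize() with the device module looked up from self.device, so the same code path works on CUDA, HPU, and any other backend that exposes a torch.<device> module. A minimal sketch of that pattern, assuming a PyTorch build that provides torch.get_device_module (which the diff itself relies on):

```python
# Device-agnostic stream synchronization via torch.get_device_module.
import torch


def sync_current_stream(device: str) -> None:
    """Synchronize the current stream of whichever backend `device` names."""
    if device == "cpu":
        return  # nothing to synchronize on CPU
    device_module = torch.get_device_module(device)  # e.g. torch.cuda, torch.hpu
    if device_module.is_available():
        device_module.current_stream().synchronize()


sync_current_stream("cuda" if torch.cuda.is_available() else "cpu")
```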
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,

@@ -73,7 +74,7 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )

@@ -97,7 +98,7 @@ class TpModelWorkerClient:
     def forward_thread_func(self):
         try:
-            with torch.cuda.stream(self.forward_stream):
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()

@@ -122,7 +123,7 @@ class TpModelWorkerClient:
         # Create event
         self.launch_done = threading.Event()
-        copy_done = torch.cuda.Event()
+        copy_done = torch.get_device_module(self.device).Event()

         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids

@@ -190,7 +191,7 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
+        torch.get_device_module(self.device).current_stream().synchronize()

         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
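Note: beyond stream synchronization, the overlap worker also needs device-agnostic Stream and Event objects and a stream context manager; the diff obtains all three from torch.get_device_module(self.device). The sketch below shows the same idea in isolation with a dummy host-to-device copy; the function name and workload are illustrative only.

```python
# Device-agnostic side-stream work with an Event to wait on just that work.
import torch


def overlapped_copy(device: str, x: torch.Tensor) -> torch.Tensor:
    dev = torch.get_device_module(device)      # torch.cuda, torch.hpu, ...
    side_stream = dev.Stream()                 # replaces torch.cuda.Stream()
    copy_done = dev.Event()                    # replaces torch.cuda.Event()

    with dev.stream(side_stream):              # run the copy off the main stream
        y = x.to(device, non_blocking=True)
        copy_done.record()

    copy_done.synchronize()                    # wait only for this copy
    return y


if torch.cuda.is_available():
    print(overlapped_copy("cuda", torch.randn(4)))
```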
python/sglang/srt/mem_cache/memory_pool.py

@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend

 logger = logging.getLogger(__name__)

@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)

     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:

@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
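Note: copy_two_array writes KV-cache entries by converting to the compute dtype and then reinterpreting the bits as a storage dtype of the same element size (view(store_dtype)). A small self-contained illustration of that dtype/store_dtype round trip is below; the float16/int16 pairing is chosen only because the two types share an item size, it is not the pairing sglang uses.

```python
# Illustration of writing through a storage dtype and reading back via view().
import torch

loc = torch.tensor([0, 2])
src = torch.randn(2, 4)

dtype = torch.float16        # compute dtype
store_dtype = torch.int16    # same item size, used only for storage

dst = torch.zeros(4, 4, dtype=store_dtype)
# Convert to the compute dtype, then reinterpret the bits as store_dtype.
dst[loc] = src.to(dtype).view(store_dtype)

# Reading back: view the stored bits as the compute dtype again.
print(dst[loc].view(dtype))
```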
python/sglang/srt/models/commandr.py

@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs


-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
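Note: the commandr change is the same pattern applied to a module-level @torch.compile: route compilation through get_compiler_backend() so Gaudi gets "hpu_backend" while other devices keep inductor. A minimal sketch is below; the normalization body is a generic RMS-style example, not Cohere's exact layer_norm_func, and pick_compiler_backend simply mirrors the helper this commit adds to utils.py.

```python
# Backend-aware torch.compile on a small normalization function.
import torch


def pick_compiler_backend() -> str:
    # Mirror of get_compiler_backend(): use the Habana compile backend on HPU,
    # otherwise the stock inductor backend.
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return "hpu_backend"
    return "inductor"


@torch.compile(backend=pick_compiler_backend())
def rms_norm_func(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float):
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states


print(rms_norm_func(torch.randn(2, 8), torch.ones(8), 1e-5))
```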
python/sglang/srt/server_args.py

@@ -25,6 +25,7 @@ import torch
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
+    get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,

@@ -158,6 +159,8 @@ class ServerArgs:
             gpu_mem = get_amdgpu_memory_capacity()
         elif torch.cuda.is_available():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif self.device == "hpu":
+            gpu_mem = get_hpu_memory_capacity()
         else:
             # GPU memory is not known yet or no GPU is available.
             gpu_mem = None

@@ -194,6 +197,10 @@ class ServerArgs:
                 self.cuda_graph_max_bs = 160

         # Choose kernel backends
+        if self.device == "hpu":
+            self.attention_backend = "torch_native"
+            self.sampling_backend = "pytorch"
+
         if self.attention_backend is None:
             self.attention_backend = (
                 "flashinfer" if is_flashinfer_available() else "triton"
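Note: the net effect of the server_args changes is that a server started with device "hpu" takes its memory budget from get_hpu_memory_capacity() and is pinned to the portable torch_native attention and pytorch sampling backends before the FlashInfer/Triton defaults are considered. The standalone sketch below restates that selection order; the function name and the non-HPU sampling default shown are illustrative, not sglang's exact API.

```python
# Illustrative restatement of the backend-default logic added for HPU.
from typing import Optional, Tuple


def pick_kernel_backends(
    device: str,
    attention_backend: Optional[str] = None,
    sampling_backend: Optional[str] = None,
    flashinfer_available: bool = False,
) -> Tuple[str, str]:
    if device == "hpu":
        # The commit pins the portable torch-native / pytorch paths on Gaudi.
        attention_backend = "torch_native"
        sampling_backend = "pytorch"
    if attention_backend is None:
        attention_backend = "flashinfer" if flashinfer_available else "triton"
    if sampling_backend is None:
        sampling_backend = "flashinfer" if flashinfer_available else "pytorch"
    return attention_backend, sampling_backend


print(pick_kernel_backends("hpu"))                               # ('torch_native', 'pytorch')
print(pick_kernel_backends("cuda", flashinfer_available=True))   # ('flashinfer', 'flashinfer')
```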
python/sglang/srt/utils.py

@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory

+    elif device == "hpu":
+        num_gpus = torch.hpu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.hpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
+                "which may cause useless memory allocation for torch HPU context.",
+            )
+
+        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)

@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
         )


+def get_hpu_memory_capacity():
+    try:
+        # Run hl-smi and capture the output
+        result = subprocess.run(
+            ["hl-smi --query | grep 'Total'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "hl-smi not found. Ensure Habana drivers are installed and accessible."
+        )
+
+
 # Copy from pytorch and OpenRLHF to allow creating multiple main groups.
 # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
 # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py

@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return major, minor


+def get_compiler_backend() -> str:
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return "hpu_backend"
+
+    return "inductor"
+
+
 sglang_lib = Library("sglang", "FRAGMENT")  # noqa
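Note: get_hpu_memory_capacity shells out to hl-smi and keeps the smallest per-card total, analogous to what get_nvgpu_memory_capacity does with nvidia-smi. The offline sketch below replays just the parsing step on a canned string, since hl-smi is only present on Gaudi hosts; the sample output is fabricated for illustration and may not match real hl-smi formatting exactly.

```python
# Offline replay of the hl-smi "Total" line parsing (sample output is fabricated).
sample_stdout = """\
Total                          : 98304 MiB
Total                          : 98304 MiB"""

memory_values = [
    float(line.split(" ")[-2]) for line in sample_stdout.strip().split("\n")
]
print(min(memory_values))  # capacity of the smallest card, in MiB
```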