sglang · Commit 37ee906f (Unverified)
Authored Dec 06, 2024 by Qun Yang; committed by GitHub on Dec 06, 2024
Parent: 34b364e0

Add more support for intel Gaudi accelerators (#2357)

Showing 8 changed files with 88 additions and 14 deletions (+88, -14)
examples/runtime/engine/offline_batch_inference.py      +13  -3
python/sglang/srt/layers/sampler.py                      +2  -0
python/sglang/srt/managers/scheduler.py                  +3  -3
python/sglang/srt/managers/tp_worker_overlap_thread.py   +6  -5
python/sglang/srt/mem_cache/memory_pool.py               +5  -1
python/sglang/srt/models/commandr.py                     +2  -2
python/sglang/srt/server_args.py                         +7  -0
python/sglang/srt/utils.py                              +50  -0
examples/runtime/engine/offline_batch_inference.py

+import argparse
+import dataclasses
+
 import sglang as sgl
+from sglang.srt.server_args import ServerArgs


-def main():
+def main(
+    server_args: ServerArgs,
+):
     # Sample prompts.
     prompts = [
         "Hello, my name is",
...
@@ -13,7 +19,7 @@ def main():
     sampling_params = {"temperature": 0.8, "top_p": 0.95}

     # Create an LLM.
-    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+    llm = sgl.Engine(**dataclasses.asdict(server_args))

     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
...
@@ -25,4 +31,8 @@ def main():
 # The __main__ condition is necessary here because we use "spawn" to create subprocesses
 # Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    main(server_args)
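With this change the example is driven entirely by ServerArgs, so any engine option (model path, device, tensor parallelism, and so on) can be passed on the command line. A minimal sketch of the equivalent programmatic use, not part of this commit; the model path is simply the one from the old hard-coded example:

import dataclasses

import sglang as sgl
from sglang.srt.server_args import ServerArgs

# Build ServerArgs directly instead of parsing CLI flags; every field not set
# here keeps its dataclass default, exactly as with the argparse path above.
server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm = sgl.Engine(**dataclasses.asdict(server_args))

outputs = llm.generate(["Hello, my name is"], {"temperature": 0.8, "top_p": 0.95})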
python/sglang/srt/layers/sampler.py

...
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
python/sglang/srt/managers/scheduler.py

...
@@ -993,7 +993,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            torch.get_device_module(self.device).current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
...
@@ -1055,7 +1055,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
         else:  # embedding or reward model
...
@@ -1130,7 +1130,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

             self.stream_output(batch.reqs)
...
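All three hunks make the same substitution: torch.get_device_module(self.device) returns the accelerator module for the configured device (torch.cuda for "cuda", torch.hpu for "hpu" once the Gaudi bridge is loaded), so the stream synchronization no longer assumes CUDA. A minimal device-agnostic sketch, assuming a PyTorch build that ships torch.get_device_module; sync_current_stream is a hypothetical helper for illustration:

import torch

def sync_current_stream(device: str) -> None:
    # torch.get_device_module("cuda") is torch.cuda; the same call resolves to
    # torch.hpu when device == "hpu", so one code path covers both backends.
    if device != "cpu":
        torch.get_device_module(device).current_stream().synchronize()

sync_current_stream("cuda" if torch.cuda.is_available() else "cpu")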
python/sglang/srt/managers/tp_worker_overlap_thread.py

...
@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,
...
@@ -73,7 +74,7 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )
...
@@ -97,7 +98,7 @@ class TpModelWorkerClient:
     def forward_thread_func(self):
         try:
-            with torch.cuda.stream(self.forward_stream):
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()
...
@@ -122,7 +123,7 @@ class TpModelWorkerClient:
         # Create event
         self.launch_done = threading.Event()
-        copy_done = torch.cuda.Event()
+        copy_done = torch.get_device_module(self.device).Event()

         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
...
@@ -190,7 +191,7 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
+        torch.get_device_module(self.device).current_stream().synchronize()
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
...
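The overlap worker creates its own stream and event objects, so the same get_device_module indirection is applied to Stream(), stream(), and Event(). A small sketch of running work on a side stream selected through the device module; run_on_side_stream is a hypothetical helper, not sglang code:

import threading
import torch

def run_on_side_stream(device: str, fn) -> None:
    # Create a stream via the device's module (torch.cuda or torch.hpu) and
    # execute fn inside it, mirroring forward_thread_func above.
    stream = torch.get_device_module(device).Stream()
    with torch.get_device_module(device).stream(stream):
        fn()

if torch.cuda.is_available():
    t = threading.Thread(
        target=run_on_side_stream,
        args=("cuda", lambda: torch.ones(4, device="cuda").sum()),
    )
    t.start()
    t.join()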
python/sglang/srt/mem_cache/memory_pool.py

...
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend

 logger = logging.getLogger(__name__)
...
@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)

     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:
...
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
...
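The new guard in free() returns early when a batch frees no KV slots, so the CPU-side free_slots bookkeeping is skipped instead of concatenating an empty tensor. A tiny standalone sketch of the same pattern (not sglang code):

import torch

free_slots = torch.arange(4, dtype=torch.int64)  # stand-in for the pool's free list
free_index = torch.empty(0, dtype=torch.int64)   # this step frees nothing

if free_index.numel() != 0:  # mirrors the early return added above
    free_slots = torch.concat((free_slots, free_index.cpu()))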
python/sglang/srt/models/commandr.py

...
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs


-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
...
python/sglang/srt/server_args.py

...
@@ -25,6 +25,7 @@ import torch
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
+    get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
...
@@ -158,6 +159,8 @@ class ServerArgs:
             gpu_mem = get_amdgpu_memory_capacity()
         elif torch.cuda.is_available():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif self.device == "hpu":
+            gpu_mem = get_hpu_memory_capacity()
         else:
             # GPU memory is not known yet or no GPU is available.
             gpu_mem = None
...
@@ -194,6 +197,10 @@ class ServerArgs:
                 self.cuda_graph_max_bs = 160

         # Choose kernel backends
+        if self.device == "hpu":
+            self.attention_backend = "torch_native"
+            self.sampling_backend = "pytorch"
+
         if self.attention_backend is None:
             self.attention_backend = (
                 "flashinfer" if is_flashinfer_available() else "triton"
...
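On HPU the flashinfer and triton kernels are unavailable, so the new branch pins the attention backend to "torch_native" and the sampling backend to "pytorch" before the usual auto-selection runs. A sketch of the resulting defaults, assuming the field names and post-init behavior of this ServerArgs version; the model path is a placeholder:

from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", device="hpu")
# With the branch above, the dataclass post-init should leave:
#   args.attention_backend == "torch_native"
#   args.sampling_backend == "pytorch"
print(args.attention_backend, args.sampling_backend)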
python/sglang/srt/utils.py

...
@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory

+    elif device == "hpu":
+        num_gpus = torch.hpu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.hpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
+                "which may cause useless memory allocation for torch HPU context.",
+            )
+
+        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
...
@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
     )


+def get_hpu_memory_capacity():
+    try:
+        # Run hl-smi and capture the output
+        result = subprocess.run(
+            ["hl-smi --query | grep 'Total'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "hl-smi not found. Ensure Habana drivers are installed and accessible."
+        )
+
+
 # Copy from pytorch and OpenRLHF to allow creating multiple main groups.
 # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
 # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
...
@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return major, minor


+def get_compiler_backend() -> str:
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return "hpu_backend"
+
+    return "inductor"
+
+
 sglang_lib = Library("sglang", "FRAGMENT")  # noqa
...
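get_compiler_backend() centralizes the backend choice for every torch.compile call touched by this commit: it returns "hpu_backend" when the Gaudi bridge has registered torch.hpu, and the stock "inductor" otherwise. A minimal usage sketch; the function being compiled is just an illustrative example, not sglang code:

import torch
from sglang.srt.utils import get_compiler_backend

@torch.compile(dynamic=True, backend=get_compiler_backend())
def scale_and_shift(x: torch.Tensor) -> torch.Tensor:
    # Compiled with "hpu_backend" on Gaudi, "inductor" everywhere else.
    return x * 2.0 + 1.0

print(scale_and_shift(torch.arange(4, dtype=torch.float32)))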