Commit 37ee906f (unverified)
Authored Dec 06, 2024 by Qun Yang; committed by GitHub on Dec 06, 2024

Add more support for intel Gaudi accelerators (#2357)

Parent: 34b364e0

Showing 8 changed files with 88 additions and 14 deletions (+88 −14)
examples/runtime/engine/offline_batch_inference.py      +13  −3
python/sglang/srt/layers/sampler.py                       +2  −0
python/sglang/srt/managers/scheduler.py                   +3  −3
python/sglang/srt/managers/tp_worker_overlap_thread.py    +6  −5
python/sglang/srt/mem_cache/memory_pool.py                +5  −1
python/sglang/srt/models/commandr.py                      +2  −2
python/sglang/srt/server_args.py                          +7  −0
python/sglang/srt/utils.py                               +50  −0

examples/runtime/engine/offline_batch_inference.py

+import argparse
+import dataclasses
+
 import sglang as sgl
+from sglang.srt.server_args import ServerArgs


-def main():
+def main(
+    server_args: ServerArgs,
+):
     # Sample prompts.
     prompts = [
         "Hello, my name is",

@@ -13,7 +19,7 @@ def main():
     sampling_params = {"temperature": 0.8, "top_p": 0.95}

     # Create an LLM.
-    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+    llm = sgl.Engine(**dataclasses.asdict(server_args))

     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.

@@ -25,4 +31,8 @@ def main():
 # The __main__ condition is necessary here because we use "spawn" to create subprocesses
 # Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    main(server_args)
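
The example now takes its full engine configuration from the ServerArgs CLI parser (e.g. --model-path, including the new HPU-related defaults) instead of a hard-coded model path. A minimal sketch of the same pattern, constructing ServerArgs directly in Python rather than via argparse (the model path and sampling values are the illustrative ones from the example itself):

    # Sketch only: build ServerArgs in code and fan its fields into sgl.Engine.
    import dataclasses

    import sglang as sgl
    from sglang.srt.server_args import ServerArgs

    if __name__ == "__main__":  # required because sgl.Engine spawns subprocesses
        server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
        llm = sgl.Engine(**dataclasses.asdict(server_args))
        outputs = llm.generate(["Hello, my name is"], {"temperature": 0.8, "top_p": 0.95})
        print(outputs)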

python/sglang/srt/layers/sampler.py

@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
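
The new cast applies to the source tensor of the gather (the sorted token ids), not to the sampling index returned by torch.multinomial; token ids fit comfortably in int32. A minimal standalone sketch of that step (tensor values are illustrative):

    import torch

    probs_sort = torch.tensor([[0.7, 0.2, 0.1]])                  # sorted probabilities
    probs_idx = torch.tensor([[42, 7, 99]]).to(torch.int32)       # sorted token ids, int32 is enough
    sampled_index = torch.multinomial(probs_sort, num_samples=1)  # int64 position per row
    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    print(batch_next_token_ids)  # one of 42, 7, or 99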

python/sglang/srt/managers/scheduler.py

@@ -993,7 +993,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            torch.get_device_module(self.device).current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):

@@ -1055,7 +1055,7 @@ class Scheduler:
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                torch.get_device_module(self.device).current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model

@@ -1130,7 +1130,7 @@ class Scheduler:
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            torch.get_device_module(self.device).current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs)
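
All three hunks make the same substitution: the stream synchronization goes through torch.get_device_module(self.device), so the call resolves to torch.cuda on NVIDIA GPUs and to torch.hpu on Gaudi at runtime. A minimal sketch of the pattern (assumes a PyTorch build that provides torch.get_device_module and an available accelerator):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cpu":
        # Resolves to torch.cuda here; on a Gaudi host it would resolve to torch.hpu.
        torch.get_device_module(device).current_stream().synchronize()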

python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,

@@ -73,7 +74,7 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )

@@ -97,7 +98,7 @@ class TpModelWorkerClient:
     def forward_thread_func(self):
         try:
-            with torch.cuda.stream(self.forward_stream):
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()

@@ -122,7 +123,7 @@ class TpModelWorkerClient:
         # Create event
         self.launch_done = threading.Event()
-        copy_done = torch.cuda.Event()
+        copy_done = torch.get_device_module(self.device).Event()

         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids

@@ -190,7 +191,7 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
+        torch.get_device_module(self.device).current_stream().synchronize()

         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
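
Beyond the stream sync, this file also creates its forward stream, stream context, and copy event through the resolved device module, and routes the torch.compile decorator through the new get_compiler_backend() helper. A minimal sketch of the stream/event part of that pattern (assumes a CUDA-capable machine; self.device is replaced by a local variable and the thread body is a placeholder):

    import threading

    import torch

    device = "cuda"                        # would be "hpu" on a Gaudi host
    dev = torch.get_device_module(device)

    forward_stream = dev.Stream()          # was torch.cuda.Stream()
    copy_done = dev.Event()                # was torch.cuda.Event()

    def forward_thread_func():
        with dev.stream(forward_stream):   # was torch.cuda.stream(...)
            copy_done.record()             # placeholder for the real forward work

    threading.Thread(target=forward_thread_func).start()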

python/sglang/srt/mem_cache/memory_pool.py

@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend

 logger = logging.getLogger(__name__)

@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)

     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:

@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
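
The new guard in free() returns early for an empty index tensor, skipping both the concat and the device-to-host copy implied by .cpu(). A minimal sketch of the effect (free_slots and free_index are illustrative stand-ins for the pool's state):

    import torch

    free_slots = torch.arange(8)
    free_index = torch.empty(0, dtype=torch.int64)  # nothing to release

    if free_index.numel() == 0:
        pass  # early return in the pool: no concat, no .cpu() transfer
    else:
        free_slots = torch.concat((free_slots, free_index.cpu()))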

python/sglang/srt/models/commandr.py

@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs


-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
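
The bare @torch.compile now takes an explicit backend, so this layer norm compiles with Inductor on CUDA and with hpu_backend on Gaudi. A minimal sketch of the decorator change using the helper added in this commit (the norm body below is a simplified stand-in, not the model's exact code):

    import torch

    from sglang.srt.utils import get_compiler_backend

    @torch.compile(backend=get_compiler_backend())
    def layer_norm_like(hidden_states, weight, variance_epsilon):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
        return (weight.to(torch.float32) * hidden_states).to(input_dtype)

    out = layer_norm_like(torch.randn(2, 8), torch.ones(8), 1e-5)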

python/sglang/srt/server_args.py

@@ -25,6 +25,7 @@ import torch
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
+    get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,

@@ -158,6 +159,8 @@ class ServerArgs:
             gpu_mem = get_amdgpu_memory_capacity()
         elif torch.cuda.is_available():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif self.device == "hpu":
+            gpu_mem = get_hpu_memory_capacity()
         else:
             # GPU memory is not known yet or no GPU is available.
             gpu_mem = None

@@ -194,6 +197,10 @@ class ServerArgs:
                 self.cuda_graph_max_bs = 160

         # Choose kernel backends
+        if self.device == "hpu":
+            self.attention_backend = "torch_native"
+            self.sampling_backend = "pytorch"
+
         if self.attention_backend is None:
             self.attention_backend = (
                 "flashinfer" if is_flashinfer_available() else "triton"
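
With these defaults, passing --device hpu (or constructing ServerArgs with device="hpu") forces the portable torch_native attention backend and the pytorch sampling backend before the flashinfer/triton fallback logic runs. A minimal sketch (assumes a Gaudi host, since ServerArgs initialization also queries HPU memory via hl-smi; the model path is illustrative):

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        device="hpu",
    )
    print(args.attention_backend, args.sampling_backend)  # expected: torch_native pytorch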

python/sglang/srt/utils.py

@@ -201,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory

+    elif device == "hpu":
+        num_gpus = torch.hpu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.hpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
+                "which may cause useless memory allocation for torch HPU context.",
+            )
+
+        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)

@@ -939,6 +951,37 @@ def get_nvgpu_memory_capacity():
     )


+def get_hpu_memory_capacity():
+    try:
+        # Run hl-smi and capture the output
+        result = subprocess.run(
+            ["hl-smi --query | grep 'Total'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "hl-smi not found. Ensure Habana drivers are installed and accessible."
+        )
+
+
 # Copy from pytorch and OpenRLHF to allow creating multiple main groups.
 # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
 # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py

@@ -1062,6 +1105,13 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return major, minor


+def get_compiler_backend() -> str:
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return "hpu_backend"
+
+    return "inductor"
+
+
 sglang_lib = Library("sglang", "FRAGMENT")  # noqa
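
get_hpu_memory_capacity takes the second-to-last whitespace token of each "Total" line emitted by hl-smi as a MiB value and returns the minimum across devices, mirroring the get_nvgpu_memory_capacity helper above it. A minimal sketch of just the parsing step (the sample lines below are assumed output, not captured from a real hl-smi run):

    sample_stdout = "Total  : 98304 MiB\nTotal  : 98304 MiB"  # assumed hl-smi format
    memory_values = [
        float(mem.split(" ")[-2]) for mem in sample_stdout.strip().split("\n")
    ]
    print(min(memory_values))  # 98304.0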