sglang commit 8275049c (unverified): Add device support (#1607)
Authored by Zhang, Liangang on Oct 11, 2024; committed by GitHub on Oct 11, 2024.
Parent: 5476ccad
Showing 5 changed files with 96 additions and 52 deletions (+96 / -52):

python/sglang/srt/managers/schedule_batch.py         +4   -0
python/sglang/srt/model_executor/model_runner.py     +39  -34
python/sglang/srt/sampling/sampling_batch_info.py    +19  -7
python/sglang/srt/server_args.py                     +8   -0
python/sglang/srt/utils.py                           +26  -11
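The five files carry a single change from the CLI down to tensor allocation: ServerArgs gains a device field and a --device flag, ModelRunner copies it into self.device and forwards it to the memory pools and to get_available_gpu_memory, and ScheduleBatch / SamplingBatchInfo record the batch's device instead of hard-coding "cuda". A minimal sketch of that plumbing pattern, using hypothetical Toy* classes rather than the real sglang ones:

from dataclasses import dataclass

import torch


@dataclass
class ToyServerArgs:
    # Mirrors the new ServerArgs.device default added by this commit.
    device: str = "cuda"


class ToyRunner:
    def __init__(self, server_args: ToyServerArgs, gpu_id: int = 0):
        # ModelRunner now keeps the configured device instead of assuming CUDA.
        self.device = server_args.device
        self.gpu_id = gpu_id
        # Downstream allocations receive the device explicitly (as the pools do in the diff).
        self.req_to_token = torch.zeros((8, 16), dtype=torch.int32, device=self.device)


runner = ToyRunner(ToyServerArgs(device="cpu"))  # "cuda" or "xpu" on real hardware
print(runner.req_to_token.device)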
python/sglang/srt/managers/schedule_batch.py

@@ -423,6 +423,9 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False

+    # device
+    device: str = "cuda"
+
     # Has regex
     has_regex: bool = False

@@ -439,6 +442,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            device=req_to_token_pool.device,
             has_regex=has_regex,
         )
python/sglang/srt/model_executor/model_runner.py

@@ -81,10 +81,11 @@ class ModelRunner:
         # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
+        self.device = server_args.device
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
-        self.nccl_port = nccl_port
+        self.dist_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
@@ -132,39 +133,45 @@ class ModelRunner:
             server_args.max_running_requests,
             server_args.max_total_tokens,
         )
-        self.init_cublas()
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+        if self.device == "cuda":
+            self.init_cublas()
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+        else:
+            self.init_attention_backend()

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
         # Init torch distributed
-        torch.cuda.set_device(self.gpu_id)
-        logger.info("Init nccl begin.")
+        if self.device == "cuda":
+            torch.cuda.set_device(self.gpu_id)
+            backend = "nccl"

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)

         if self.server_args.dist_init_addr:
-            nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
+            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
-            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
-            backend="nccl",
+            backend=backend,
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method,
+            distributed_init_method=dist_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         min_per_gpu_memory = get_available_gpu_memory(
-            self.gpu_id, distributed=self.tp_size > 1
+            self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()

         # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
         # so we disable padding in cuda graph.
-        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+        if self.device == "cuda" and not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        ):
             self.server_args.disable_cuda_graph_padding = True
             logger.info(
                 "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
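The hunk above derives the process-group backend from the device ("nccl" when self.device is "cuda") and renames the nccl_* names to dist_*. The set_custom_all_reduce / init_distributed_environment helpers come from vllm and are not reproduced here; below is a standalone sketch of the same backend-selection pattern with plain torch.distributed. The "gloo" fallback is an illustrative assumption, not what this commit configures for other devices:

import torch
import torch.distributed as dist


def init_single_rank_group(device: str, gpu_id: int = 0, dist_port: int = 29500) -> None:
    # Choose the backend from the device type, as init_torch_distributed now does for CUDA.
    if device == "cuda":
        torch.cuda.set_device(gpu_id)
        backend = "nccl"  # the backend the commit selects for CUDA
    else:
        backend = "gloo"  # illustrative fallback for this sketch only

    dist_init_method = f"tcp://127.0.0.1:{dist_port}"
    dist.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=1,
        rank=0,
    )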
@@ -172,7 +179,7 @@ class ModelRunner:

         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."

@@ -182,23 +189,22 @@ class ModelRunner:

     def load_model(self):
         logger.info(
-            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

         # This can reduce thread conflicts and speed up weight loading.
         torch.set_num_threads(1)
-        if torch.cuda.get_device_capability()[0] < 8:
-            logger.info(
-                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-            )
-            self.server_args.dtype = "float16"
-        if torch.cuda.get_device_capability()[1] < 5:
-            raise RuntimeError("SGLang only supports sm75 and above.")
+        if self.device == "cuda":
+            if torch.cuda.get_device_capability()[0] < 8:
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
+                self.server_args.dtype = "float16"
+            if torch.cuda.get_device_capability()[1] < 5:
+                raise RuntimeError("SGLang only supports sm75 and above.")

         # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
-        self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
         self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,

@@ -220,7 +226,7 @@ class ModelRunner:
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
-            device_config=self.device_config,
+            device_config=DeviceConfig(self.device),
             parallel_config=None,
             scheduler_config=None,
             lora_config=None,

@@ -240,7 +246,7 @@ class ModelRunner:
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

     def update_weights(self, model_path: str, load_format: str):

@@ -254,10 +260,10 @@ class ModelRunner:
         logger.info(
             f"Update weights begin. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
-        target_device = torch.device(self.device_config.device)
+        target_device = torch.device(self.device)

         try:
             # TODO: Use a better method to check this
@@ -343,7 +349,7 @@ class ModelRunner:

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.gpu_id, distributed=self.tp_size > 1
+            self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA

@@ -409,11 +415,10 @@ class ModelRunner:
             4096,
         )

-        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
-            device=device,
+            device=self.device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA

@@ -425,7 +430,7 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
-                device=device,
+                device=self.device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(

@@ -434,11 +439,11 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
-                device=device,
+                device=self.device,
             )
         logger.info(
             f"Memory pool end. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

     def init_cublas(self):
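With these hunks the hard-coded device = "cuda" local is gone and the request-to-token and KV-cache pools are created on self.device. A simplified, device-parameterized pool allocation is sketched below; the real ReqToTokenPool takes size, max_context_len, and device per the hunk above, but the internal tensor layout here is an assumption:

import torch


class ToyReqToTokenPool:
    """Simplified stand-in for a device-aware request-to-token pool."""

    def __init__(self, size: int, max_context_len: int, device: str):
        # One row per request slot, one column per token position.
        self.req_to_token = torch.zeros(
            (size, max_context_len), dtype=torch.int32, device=device
        )
        self.device = device


pool = ToyReqToTokenPool(size=4, max_context_len=32, device="cpu")  # "cuda"/"xpu" on real hardware
print(pool.req_to_token.shape, pool.device)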
python/sglang/srt/sampling/sampling_batch_info.py

@@ -37,6 +37,9 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None

+    # Device
+    device: str = "cuda"
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs

@@ -62,6 +65,7 @@ class SamplingBatchInfo:
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
+            device=batch.input_ids.device,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -75,7 +79,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device="cuda",
+            device=batch.input_ids.device,
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,

@@ -107,7 +111,7 @@ class SamplingBatchInfo:
                 self.linear_penalties = torch.zeros(
                     (bs, self.vocab_size),
                     dtype=torch.float32,
-                    device="cuda",
+                    device=self.device,
                 )
             self.linear_penalties = penalizer.apply(self.linear_penalties)

@@ -119,7 +123,10 @@ class SamplingBatchInfo:
         if has_regex:
             self.vocab_mask = torch.zeros(
-                len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
+                len(self.temperatures),
+                self.vocab_size,
+                dtype=torch.bool,
+                device=self.device,
             )
             for i, regex_fsm in enumerate(self.regex_fsms):
                 if regex_fsm is not None:
@@ -144,7 +151,12 @@ class SamplingBatchInfo:

     @staticmethod
     def merge_bias_tensor(
-        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+        lhs: torch.Tensor,
+        rhs: torch.Tensor,
+        bs1: int,
+        bs2: int,
+        device: str,
+        default: int = 0,
     ):
         # bias tensor can be None
         if lhs is not None or rhs is not None:

@@ -155,9 +167,9 @@ class SamplingBatchInfo:
             shape, dtype = rhs.shape[1:], rhs.dtype

             with torch.dtype(dtype):
                 if lhs is None:
-                    lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
+                    lhs = torch.empty((bs1, *shape), device=device).fill_(default)
                 if rhs is None:
-                    rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
+                    rhs = torch.empty((bs2, *shape), device=device).fill_(default)
             return torch.cat([lhs, rhs])

         return None

@@ -176,5 +188,5 @@ class SamplingBatchInfo:
                 setattr(self, item, torch.concat([self_val, other_val]))

         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other)
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
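merge_bias_tensor now takes the batch's device so that a missing bias tensor is materialized on the same device as the existing one before concatenation. A trimmed, runnable sketch of that logic (the with torch.dtype(...) line from the original is left out, and torch.full stands in for empty().fill_()):

from typing import Optional

import torch


def merge_bias_tensor(
    lhs: Optional[torch.Tensor],
    rhs: Optional[torch.Tensor],
    bs1: int,
    bs2: int,
    device: str,
    default: float = 0.0,
) -> Optional[torch.Tensor]:
    # Both sides absent: nothing to merge.
    if lhs is None and rhs is None:
        return None
    # Take shape/dtype from whichever side exists, then fill the missing side on `device`.
    ref = lhs if lhs is not None else rhs
    shape, dtype = ref.shape[1:], ref.dtype
    if lhs is None:
        lhs = torch.full((bs1, *shape), default, dtype=dtype, device=device)
    if rhs is None:
        rhs = torch.full((bs2, *shape), default, dtype=dtype, device=device)
    return torch.cat([lhs, rhs])


bias = torch.randn(2, 8)
merged = merge_bias_tensor(bias, None, bs1=2, bs2=3, device="cpu")
print(merged.shape)  # torch.Size([5, 8])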
python/sglang/srt/server_args.py

@@ -36,6 +36,7 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    device: str = "cuda"
     kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None

@@ -237,6 +238,13 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
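The new ServerArgs.device field is exposed on the command line as --device, currently restricted to "cuda" via choices (the XPU branch added in utils.py is not advertised here yet). A small standalone argparse sketch mirroring the added flag, outside sglang's real parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    choices=["cuda"],  # only CUDA is accepted by this commit's flag
    help="The device type.",
)

args = parser.parse_args(["--device", "cuda"])
print(args.device)  # "cuda"; any other value exits with an argparse error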
python/sglang/srt/utils.py

@@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(gpu_id, distributed=False):
+def get_available_gpu_memory(device, gpu_id, distributed=False):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
-    num_gpus = torch.cuda.device_count()
-    assert gpu_id < num_gpus
+    if device == "cuda":
+        num_gpus = torch.cuda.device_count()
+        assert gpu_id < num_gpus

-    if torch.cuda.current_device() != gpu_id:
-        print(
-            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
-            "which may cause useless memory allocation for torch CUDA context.",
-        )
+        if torch.cuda.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+                "which may cause useless memory allocation for torch CUDA context.",
+            )

-    torch.cuda.empty_cache()
-    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+        torch.cuda.empty_cache()
+        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+
+    elif device == "xpu":
+        num_gpus = torch.xpu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.xpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
+                "which may cause useless memory allocation for torch XPU context.",
+            )
+        torch.xpu.empty_cache()
+        used_memory = torch.xpu.memory_allocated()
+        total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
+        free_gpu_memory = total_gpu_memory - used_memory

     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device("cuda", gpu_id)
+            torch.device(device, gpu_id)
         )
         torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()
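get_available_gpu_memory now takes the device type as its first positional argument and branches between torch.cuda and torch.xpu queries. Its callers in model_runner.py format the result with a "GB" suffix, so the value is presumably returned in gigabytes by code outside this hunk. A hedged usage sketch (assumes a CUDA-enabled PyTorch build with at least one visible GPU):

from sglang.srt.utils import get_available_gpu_memory

avail = get_available_gpu_memory("cuda", 0)   # device type now comes first
# avail = get_available_gpu_memory("xpu", 0)  # Intel XPU path added by this commit
print(f"avail mem={avail:.2f} GB")            # callers format the value as GB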