sglang / Commits / 5a261bd0

Unverified commit 5a261bd0, authored Aug 16, 2024 by Lianmin Zheng, committed via GitHub on Aug 16, 2024

Fix the deadlock in multi-node tp (#1122)

parent 6aa8ad14
Showing 7 changed files with 54 additions and 16 deletions (+54 -16)
benchmark/gsm8k/bench_sglang.py                         +3   -1
python/sglang/srt/layers/logits_processor.py            +6   -4
python/sglang/srt/managers/schedule_batch.py            +13  -1
python/sglang/srt/managers/tp_worker.py                 +10  -6
python/sglang/srt/model_executor/cuda_graph_runner.py   +14  -1
python/sglang/srt/model_executor/model_runner.py        +5   -1
python/sglang/srt/models/grok.py                        +3   -2
benchmark/gsm8k/bench_sglang.py

@@ -64,7 +64,9 @@ def main(args):
     @sgl.function
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
-        s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )

     #####################################
     ########## SGL Program End ##########
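The only change to the benchmark is the extra "<|separator|>" stop string, presumably so that generation also terminates cleanly for models whose templates emit that token. A rough, hypothetical driver for the program above (it assumes a local sglang server and the standard RuntimeEndpoint API; the real benchmark builds few_shot_examples from the GSM8K train split and evaluates with run_batch):

import sglang as sgl

# Stub few-shot prompt; the benchmark assembles this from real GSM8K examples.
few_shot_examples = "Question: What is 2 + 2?\nAnswer: The answer is 4.\n\n"

@sgl.function
def few_shot_gsm8k(s, question):
    s += few_shot_examples + question
    s += sgl.gen(
        "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
    )

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = few_shot_gsm8k.run(
    question="Question: Tom has 3 apples and buys 2 more. How many does he have?\nAnswer:",
    temperature=0,
)
print(state["answer"])  # stops at "Question", "Assistant:", or "<|separator|>"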
python/sglang/srt/layers/logits_processor.py

@@ -67,10 +67,12 @@ class LogitsMetadata:
 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )

     def _get_normalized_prompt_logprobs(
         self, input_token_logprobs, logits_metadata: LogitsMetadata

@@ -159,7 +161,7 @@ class LogitsProcessor(nn.Module):
         last_hidden = hidden_states[last_index]
         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()

@@ -204,7 +206,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             all_logits = torch.matmul(hidden_states, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
             all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size].float()
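The renamed flag makes the all-gather conditional on how the LM head is laid out rather than on tp_size alone: with a vocab-parallel head each rank holds only a slice of the vocabulary and the gather is required, but with a replicated head (as Grok uses below, passing skip_all_gather=True) every rank already computes full-vocab logits. A condensed sketch of the resulting control flow, with everything except the names shown in the diff treated as illustrative:

import torch

def project_to_logits(hidden, lm_head_weight, vocab_size,
                      do_tensor_parallel_all_gather, all_gather_fn):
    # Each rank multiplies by whatever slice of the LM head it holds.
    logits = torch.matmul(hidden, lm_head_weight.T)
    if do_tensor_parallel_all_gather:
        # Vocab-parallel head: stitch the per-rank vocab shards together.
        logits = all_gather_fn(logits)
    # Replicated head (skip_all_gather=True): the full vocabulary is already
    # present on every rank, so no collective is issued here.
    return logits[:, :vocab_size].float()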
python/sglang/srt/managers/schedule_batch.py

@@ -21,7 +21,9 @@ from dataclasses import dataclass
 from typing import List, Optional, Union

 import torch
+import torch.distributed as dist
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from vllm.distributed import get_tensor_model_parallel_group

 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config

@@ -724,7 +726,7 @@ class ScheduleBatch:
             )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

-    def sample(self, logits: torch.Tensor):
+    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()

@@ -779,6 +781,16 @@ class ScheduleBatch:
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

+        if is_multi_node_tp:
+            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
+            # that can cause the TP workers to generate different tokens, so we need to
+            # sync here
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids
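This hunk is the heart of the fix at sampling time: when tensor parallelism crosses nodes, small numerical differences can make TP ranks sample different token ids; once their outputs diverge, the ranks stop issuing matching collectives and the whole group deadlocks. Reducing the sampled ids with ReduceOp.MIN forces every rank onto identical tokens before decoding continues. A toy single-process illustration of what the element-wise MIN reduction agrees on (the real code runs torch.distributed.all_reduce over the TP device group):

import torch

rank0_ids = torch.tensor([311, 1729, 42])     # hypothetical samples on rank 0
rank1_ids = torch.tensor([311, 1730, 42])     # rank 1 drifted on one position
agreed = torch.minimum(rank0_ids, rank1_ids)  # what ReduceOp.MIN computes
print(agreed)  # tensor([ 311, 1729,   42]) -- all ranks decode the same ids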
python/sglang/srt/managers/tp_worker.py

@@ -85,10 +85,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,

@@ -175,6 +171,10 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()

+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(

@@ -444,7 +444,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids = batch.sample(output.next_token_logits)
+            next_token_ids = batch.sample(
+                output.next_token_logits, self.model_runner.is_multi_node_tp
+            )

             # Move logprobs to cpu
             if output.next_token_logprobs is not None:

@@ -603,7 +605,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(
+            output.next_token_logits, self.model_runner.is_multi_node_tp
+        )

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
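Besides moving the chunked-prefill initialization later in __init__, the worker now threads the runner's node-locality flag into both sampling call sites. A minimal sketch with mock objects (not the sglang classes) showing that the keyword default keeps existing single-node calls, batch.sample(logits), behaving exactly as before:

import torch

class MockBatch:
    def sample(self, logits: torch.Tensor, is_multi_node_tp: bool = False):
        next_ids = torch.argmax(logits, dim=-1)  # stand-in for real sampling
        if is_multi_node_tp:
            pass  # the real code all-reduces next_ids across the TP group here
        return next_ids

class MockRunner:
    is_multi_node_tp = True  # would come from ModelRunner's locality check

batch, runner = MockBatch(), MockRunner()
next_token_ids = batch.sample(torch.randn(2, 8), runner.is_multi_node_tp)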
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -142,7 +142,7 @@ class CudaGraphRunner:
             set_torch_compile_config()

     def can_run(self, batch_size):
-        return batch_size < self.max_bs
+        return batch_size <= self.max_bs

     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list

@@ -239,12 +239,23 @@ class CudaGraphRunner:
             return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()

+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()

         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper

@@ -278,7 +289,9 @@ class CudaGraphRunner:
         )

         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
+        torch.cuda.synchronize()
         output = self.output_buffers[bs]

         # Unpad
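CUDA graph capture is the other place the deadlock showed up: if one rank enters capture while another is still launching the warmup's NCCL kernels, their collectives no longer pair up and both sides hang. Bracketing every phase with a device synchronize plus a process-group barrier keeps the ranks in lock-step. A sketch of that pattern under plain torch.distributed (the actual code calls the vllm tp_group's own barrier()):

import torch
import torch.distributed as dist

def capture_in_lockstep(run_once, tp_group, pool=None, warmup_iters=2):
    graph = torch.cuda.CUDAGraph()
    for _ in range(warmup_iters):
        torch.cuda.synchronize()        # all local kernels have finished
        dist.barrier(group=tp_group)    # every rank has reached this point
        run_once()                      # warmup launch (allocator / NCCL init)
    torch.cuda.synchronize()
    dist.barrier(group=tp_group)
    with torch.cuda.graph(graph, pool=pool):
        out = run_once()                # the launch that actually gets captured
    torch.cuda.synchronize()
    dist.barrier(group=tp_group)        # no rank replays before capture ends everywhere
    return graph, out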
python/sglang/srt/model_executor/model_runner.py

@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

@@ -112,10 +113,13 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
+        self.tp_group = get_tp_group()
+        self.is_multi_node_tp = not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        )

         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
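The new is_multi_node_tp flag comes from vllm's in_the_same_node_as helper, which (as used here) yields one boolean per rank of the TP CPU group indicating whether that rank shares a host with source_rank 0; the TP group spans multiple nodes exactly when not all of them are true. A toy illustration with a made-up flag list:

# Hypothetical result for 8 TP ranks split across two hosts.
same_node_as_rank0 = [True, True, True, True, False, False, False, False]
is_multi_node_tp = not all(same_node_as_rank0)
print(is_multi_node_tp)  # True -- the sampling sync and capture barriers kick in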
python/sglang/srt/models/grok.py

@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
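Switching Grok's head to ReplicatedLinear is what makes skip_all_gather=True safe: every rank now projects to the full vocabulary on its own, so the LogitsProcessor gather would only concatenate identical copies. A toy shape check with made-up sizes (not Grok-1's real configuration):

tp_size, vocab_size, batch = 4, 32000, 2

shard_logits_shape = (batch, vocab_size // tp_size)  # per rank, vocab-parallel head
gathered_shape = (batch, vocab_size)                 # after the all-gather
replicated_logits_shape = (batch, vocab_size)        # per rank, replicated head

assert gathered_shape == replicated_logits_shape     # the gather adds nothing here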