sglang commit 5a261bd0 (unverified)
Fix the deadlock in multi-node tp (#1122)
Authored Aug 16, 2024 by Lianmin Zheng; committed via GitHub on Aug 16, 2024.
Parent: 6aa8ad14

Showing 7 changed files with 54 additions and 16 deletions (+54 -16):
  benchmark/gsm8k/bench_sglang.py                        +3  -1
  python/sglang/srt/layers/logits_processor.py           +6  -4
  python/sglang/srt/managers/schedule_batch.py           +13 -1
  python/sglang/srt/managers/tp_worker.py                +10 -6
  python/sglang/srt/model_executor/cuda_graph_runner.py  +14 -1
  python/sglang/srt/model_executor/model_runner.py       +5  -1
  python/sglang/srt/models/grok.py                       +3  -2
benchmark/gsm8k/bench_sglang.py

@@ -64,7 +64,9 @@ def main(args):
     @sgl.function
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
-        s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )

     #####################################
     ########## SGL Program End ##########
python/sglang/srt/layers/logits_processor.py

@@ -67,10 +67,12 @@ class LogitsMetadata:
 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )

     def _get_normalized_prompt_logprobs(
         self, input_token_logprobs, logits_metadata: LogitsMetadata
@@ -159,7 +161,7 @@ class LogitsProcessor(nn.Module):
             last_hidden = hidden_states[last_index]

         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -204,7 +206,7 @@ class LogitsProcessor(nn.Module):
                 )
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
-            if self.tp_size > 1:
+            if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
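The change above replaces the unconditional "tp_size > 1" check with a
do_tensor_parallel_all_gather flag, so a model whose LM head is replicated on
every rank (see the grok.py change below) can opt out of the vocab all-gather.
A minimal sketch of that gating pattern, using plain torch.distributed and
hypothetical names rather than the sglang/vLLM helpers, assuming a process
group has already been initialized (e.g. via torchrun):

    import torch
    import torch.distributed as dist

    def maybe_all_gather_logits(
        local_logits: torch.Tensor, skip_all_gather: bool = False, group=None
    ) -> torch.Tensor:
        # Only gather when the vocab dimension is actually sharded across ranks.
        do_all_gather = not skip_all_gather and dist.get_world_size(group=group) > 1
        if not do_all_gather:
            return local_logits
        shards = [
            torch.empty_like(local_logits) for _ in range(dist.get_world_size(group=group))
        ]
        dist.all_gather(shards, local_logits, group=group)
        # Concatenate the per-rank vocab shards into the full logits tensor.
        return torch.cat(shards, dim=-1)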
python/sglang/srt/managers/schedule_batch.py

@@ -21,7 +21,9 @@ from dataclasses import dataclass
 from typing import List, Optional, Union

 import torch
+import torch.distributed as dist
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from vllm.distributed import get_tensor_model_parallel_group

 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
@@ -724,7 +726,7 @@ class ScheduleBatch:
             )
             self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

-    def sample(self, logits: torch.Tensor):
+    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
@@ -779,6 +781,16 @@ class ScheduleBatch:
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

+        if is_multi_node_tp:
+            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
+            # that can cause the TP workers to generate different tokens, so we need to
+            # sync here
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids
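This hunk is the heart of the fix named in the commit title: the in-code
comment notes that cross-node indeterminism can make TP workers sample
different tokens, and if that happens the workers' per-request state can drift
apart until their collectives no longer match and they hang. Reducing the
sampled ids with MIN forces every rank to adopt the same token. A minimal
stand-alone sketch of that synchronization, assuming an initialized process
group (e.g. launched with torchrun) and using the default group instead of
sglang's TP group:

    import torch
    import torch.distributed as dist

    def sync_sampled_tokens(next_token_ids: torch.Tensor, group=None) -> torch.Tensor:
        # Element-wise MIN across ranks picks the same id on every rank, so the
        # schedulers on all TP workers keep making identical decisions.
        dist.all_reduce(next_token_ids, op=dist.ReduceOp.MIN, group=group)
        return next_token_ids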
python/sglang/srt/managers/tp_worker.py

@@ -85,10 +85,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -175,6 +171,10 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()

+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
@@ -444,7 +444,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids = batch.sample(output.next_token_logits)
+            next_token_ids = batch.sample(
+                output.next_token_logits, self.model_runner.is_multi_node_tp
+            )

             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -603,7 +605,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(
+            output.next_token_logits, self.model_runner.is_multi_node_tp
+        )

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
5a261bd0
...
...
@@ -142,7 +142,7 @@ class CudaGraphRunner:
set_torch_compile_config
()
def
can_run
(
self
,
batch_size
):
return
batch_size
<
self
.
max_bs
return
batch_size
<
=
self
.
max_bs
def
capture
(
self
,
batch_size_list
):
self
.
batch_size_list
=
batch_size_list
...
...
@@ -239,12 +239,23 @@ class CudaGraphRunner:
return
forward
(
input_ids
,
input_metadata
.
positions
,
input_metadata
)
for
_
in
range
(
2
):
torch
.
cuda
.
synchronize
()
self
.
model_runner
.
tp_group
.
barrier
()
run_once
()
torch
.
cuda
.
synchronize
()
self
.
model_runner
.
tp_group
.
barrier
()
torch
.
cuda
.
synchronize
()
self
.
model_runner
.
tp_group
.
barrier
()
with
torch
.
cuda
.
graph
(
graph
,
pool
=
self
.
graph_memory_pool
,
stream
=
stream
):
out
=
run_once
()
torch
.
cuda
.
synchronize
()
self
.
model_runner
.
tp_group
.
barrier
()
self
.
graph_memory_pool
=
graph
.
pool
()
return
graph
,
None
,
out
,
flashinfer_decode_wrapper
...
...
@@ -278,7 +289,9 @@ class CudaGraphRunner:
)
# Replay
torch
.
cuda
.
synchronize
()
self
.
graphs
[
bs
].
replay
()
torch
.
cuda
.
synchronize
()
output
=
self
.
output_buffers
[
bs
]
# Unpad
...
...
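The capture path now brackets every step with a device synchronize followed by
a barrier on the TP group, so all ranks enter CUDA graph capture in lockstep
and the collectives recorded inside the graph line up across ranks. A minimal
sketch of the pattern with plain torch.distributed (a hypothetical helper, not
the CudaGraphRunner API), assuming an initialized NCCL process group with one
GPU per rank and a run_once callable that issues the model's collectives:

    import torch
    import torch.distributed as dist

    def capture_in_lockstep(run_once, group=None) -> torch.cuda.CUDAGraph:
        graph = torch.cuda.CUDAGraph()
        # Warmup runs: make sure every rank has finished eager execution
        # before any rank proceeds, so no one races ahead into capture.
        for _ in range(2):
            torch.cuda.synchronize()
            dist.barrier(group=group)
            run_once()
            torch.cuda.synchronize()
            dist.barrier(group=group)
        torch.cuda.synchronize()
        dist.barrier(group=group)
        with torch.cuda.graph(graph):
            run_once()
        torch.cuda.synchronize()
        dist.barrier(group=group)
        return graph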
python/sglang/srt/model_executor/model_runner.py

@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
@@ -112,10 +113,13 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
+        self.tp_group = get_tp_group()
+        self.is_multi_node_tp = not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        )

         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
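is_multi_node_tp is derived by asking, for every rank in the TP CPU group,
whether it sits on the same host as rank 0; if any rank does not, the MIN sync
in ScheduleBatch.sample is enabled. A rough stand-in for that check (not
vLLM's in_the_same_node_as) that simply gathers hostnames over an
already-initialized process group:

    import socket
    import torch.distributed as dist

    def is_multi_node(group=None) -> bool:
        world_size = dist.get_world_size(group=group)
        hostnames = [None] * world_size
        # Collect every rank's hostname and compare against rank 0's.
        dist.all_gather_object(hostnames, socket.gethostname(), group=group)
        return any(name != hostnames[0] for name in hostnames)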
python/sglang/srt/models/grok.py

@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)