Commit 5a261bd0 (unverified)
Authored Aug 16, 2024 by Lianmin Zheng; committed by GitHub on Aug 16, 2024
Parent: 6aa8ad14

Fix the deadlock in multi-node tp (#1122)
Showing 7 changed files with 54 additions and 16 deletions.
benchmark/gsm8k/bench_sglang.py  (+3, -1)
python/sglang/srt/layers/logits_processor.py  (+6, -4)
python/sglang/srt/managers/schedule_batch.py  (+13, -1)
python/sglang/srt/managers/tp_worker.py  (+10, -6)
python/sglang/srt/model_executor/cuda_graph_runner.py  (+14, -1)
python/sglang/srt/model_executor/model_runner.py  (+5, -1)
python/sglang/srt/models/grok.py  (+3, -2)
benchmark/gsm8k/bench_sglang.py

@@ -64,7 +64,9 @@ def main(args):
     @sgl.function
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
-        s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )

     #####################################
     ########## SGL Program End ##########
python/sglang/srt/layers/logits_processor.py

@@ -67,10 +67,12 @@ class LogitsMetadata:
 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )

     def _get_normalized_prompt_logprobs(
         self, input_token_logprobs, logits_metadata: LogitsMetadata

@@ -159,7 +161,7 @@ class LogitsProcessor(nn.Module):
             last_hidden = hidden_states[last_index]

         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()

@@ -204,7 +206,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             all_logits = torch.matmul(hidden_states, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
             all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size].float()
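For intuition, here is a minimal standalone sketch of the flag logic introduced above, not the sglang implementation itself; the tp_world_size parameter below stands in for get_tensor_model_parallel_world_size() and is an assumption for illustration. When a model produces full-vocabulary logits on every rank (for example via a replicated LM head), the cross-rank all-gather of vocab-parallel logits is redundant and can be skipped.

# Toy illustration of the do_tensor_parallel_all_gather decision (not sglang code).
def needs_all_gather(tp_world_size: int, skip_all_gather: bool = False) -> bool:
    # Gather vocab-parallel logits only when TP is actually used and the model
    # has not opted out (e.g. a replicated LM head already holds the full vocab).
    return (not skip_all_gather) and tp_world_size > 1

# Example: single-node TP=8 with a vocab-parallel head -> gather.
assert needs_all_gather(8) is True
# Example: a model passing skip_all_gather=True -> no gather even with TP=8.
assert needs_all_gather(8, skip_all_gather=True) is False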
python/sglang/srt/managers/schedule_batch.py

@@ -21,7 +21,9 @@ from dataclasses import dataclass
 from typing import List, Optional, Union

 import torch
+import torch.distributed as dist
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from vllm.distributed import get_tensor_model_parallel_group

 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config

@@ -724,7 +726,7 @@ class ScheduleBatch:
             )
             self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

-    def sample(self, logits: torch.Tensor):
+    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()

@@ -779,6 +781,16 @@ class ScheduleBatch:
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

+        if is_multi_node_tp:
+            # If the tensor parallelism spans across multiple nodes, there is some
+            # indeterminism that can cause the TP workers to generate different
+            # tokens, so we need to sync here
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids
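A minimal sketch of why an all-reduce with ReduceOp.MIN makes every TP rank agree on the same next token. This is a single-process simulation using plain Python min over per-rank candidates, not the torch.distributed call itself, and the rank values are made up for illustration.

# Simulating the effect of torch.distributed.all_reduce(..., op=ReduceOp.MIN)
# across TP ranks in plain Python (illustration only; values are made up).
# In multi-node TP, floating-point non-determinism can make ranks sample
# slightly different token ids; reducing with MIN forces a single consistent
# choice on every rank, so the per-rank states never diverge (which previously
# could end in a deadlock).

per_rank_sampled_ids = {
    0: [1523, 77, 9004],   # rank 0's sampled ids for a batch of 3 requests
    1: [1523, 78, 9004],   # rank 1 disagreed on the second request
}

# Element-wise MIN across ranks, which is what the all-reduce computes.
agreed_ids = [min(ids) for ids in zip(*per_rank_sampled_ids.values())]
print(agreed_ids)  # [1523, 77, 9004] on every rank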
python/sglang/srt/managers/tp_worker.py

@@ -85,10 +85,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,

@@ -175,6 +171,10 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()

+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(

@@ -444,7 +444,9 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(output.next_token_logits)
+                next_token_ids = batch.sample(
+                    output.next_token_logits, self.model_runner.is_multi_node_tp
+                )

                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:

@@ -603,7 +605,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(
+            output.next_token_logits, self.model_runner.is_multi_node_tp
+        )

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -142,7 +142,7 @@ class CudaGraphRunner:
             set_torch_compile_config()

     def can_run(self, batch_size):
-        return batch_size < self.max_bs
+        return batch_size <= self.max_bs

     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list

@@ -239,12 +239,23 @@ class CudaGraphRunner:
             return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()

         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper

@@ -278,7 +289,9 @@ class CudaGraphRunner:
         )

         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
+        torch.cuda.synchronize()
         output = self.output_buffers[bs]

         # Unpad
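For intuition, a small sketch of the synchronize-then-barrier pattern used above around CUDA graph capture. This is standalone illustrative code, not CudaGraphRunner itself: it assumes torch.distributed has already been initialized (for example via torchrun) and that each rank owns one CUDA device, and it uses torch.distributed.barrier on a passed-in group where the real code calls self.model_runner.tp_group.barrier(). The point of the pattern is that every TP rank drains its local GPU work and reaches the same program point before any rank starts or leaves capture, so no rank can block inside a collective that another rank never enters.

import torch
import torch.distributed as dist

def sync_all_ranks(group=None):
    # Drain this rank's outstanding GPU work, then wait until every rank in
    # the group has reached this point. Illustrative helper; assumes the
    # process group is already initialized.
    torch.cuda.synchronize()
    dist.barrier(group=group)

# Sketch of how the pattern brackets capture (mirrors the diff above):
#   sync_all_ranks()            # align all TP ranks before the warmup runs
#   run_once()
#   sync_all_ranks()            # align again right before capture begins
#   with torch.cuda.graph(graph):
#       out = run_once()
#   sync_all_ranks()            # and once more after capture completes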
python/sglang/srt/model_executor/model_runner.py

@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

@@ -112,10 +113,13 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
+        self.tp_group = get_tp_group()
+        self.is_multi_node_tp = not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        )

         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
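A rough sketch of the node-locality check used above. vLLM's in_the_same_node_as(cpu_group, source_rank=0) returns, for every rank in the group, whether that rank sits on the same physical node as rank 0; if any entry is False, the TP group spans more than one node. The snippet below mimics that decision with hard-coded booleans, which are illustrative values rather than the vLLM call.

# Mimicking the multi-node detection logic with made-up values
# (illustration only; the real code calls vllm's in_the_same_node_as).

# Suppose TP size is 16 across two 8-GPU nodes: ranks 0-7 share rank 0's node,
# ranks 8-15 do not.
same_node_as_rank0 = [True] * 8 + [False] * 8

is_multi_node_tp = not all(same_node_as_rank0)
print(is_multi_node_tp)  # True -> sampling will be synced with the MIN all-reduce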
python/sglang/srt/models/grok.py

@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
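A closing note on why skip_all_gather=True fits this change: ParallelLMHead shards the vocabulary dimension across TP ranks, so assembling full-vocabulary logits needs an all-gather, while ReplicatedLinear keeps the whole hidden_size by vocab_size projection on every rank. A toy shape comparison follows; the sizes are made up for illustration and are not Grok's real configuration.

# Shape-level intuition for the Grok change (toy numbers, not Grok's real sizes).
# With a vocab-parallel head, each of tp_size ranks computes vocab_size / tp_size
# logit columns and an all-gather is needed to assemble the full-vocab logits.
# With a replicated head, every rank already computes all columns.
hidden_size, vocab_size, tp_size = 4096, 128 * 1024, 8

vocab_parallel_cols_per_rank = vocab_size // tp_size  # 16384 -> gather required
replicated_cols_per_rank = vocab_size                 # 131072 -> no gather needed
print(vocab_parallel_cols_per_rank, replicated_cols_per_rank)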