sglang / Commits / dc1b8bcf

Unverified commit dc1b8bcf, authored Jul 05, 2024 by Ying Sheng, committed by GitHub on Jul 05, 2024

Format (#593)
parent 5a57b8ad

Showing 20 changed files with 484 additions and 353 deletions (+484, -353)
benchmark/latency_throughput/bench_one.py                 +1   -1
benchmark/latency_throughput/bench_serving.py             +3   -2
benchmark/line_retrieval/gen_data.py                      +3   -3
benchmark/mmlu/bench_sglang.py                            +9   -5
python/sglang/bench_latency.py                            +28  -12
python/sglang/global_config.py                            +1   -0
python/sglang/lang/ir.py                                  +4   -2
python/sglang/srt/constrained/__init__.py                 +3   -2
python/sglang/srt/hf_transformers_utils.py                +7   -3
python/sglang/srt/layers/fused_moe.py                     +181 -167
python/sglang/srt/layers/radix_attention.py               +9   -2
python/sglang/srt/managers/controller/manager_single.py   +1   -1
python/sglang/srt/managers/controller/model_runner.py     +20  -11
python/sglang/srt/managers/controller/tp_worker.py        +31  -14
python/sglang/srt/managers/tokenizer_manager.py           +8   -8
python/sglang/srt/models/gemma2.py                        +74  -60
python/sglang/srt/models/llama_classification.py          +10  -7
python/sglang/srt/server.py                               +17  -8
python/sglang/srt/server_args.py                          +8   -16
python/sglang/srt/utils.py                                +66  -29
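The commit is a pure formatting pass over these files. As a minimal sketch of the kind of rewrap seen in the diffs below, assuming a Black-style formatter at the default 88-column limit (the commit message does not actually name the tool), the long assignment from bench_serving.py can be fed through Black's Python API:

# Sketch only: assumes the `black` package is installed; the source line is copied
# from the bench_serving.py hunk below.
import black

src = (
    "decoding_throughput = np.sum("
    "[output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time\n"
)
# Black rewraps the over-long assignment inside parentheses, similar to the hunk below.
print(black.format_str(src, mode=black.Mode(line_length=88)))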
benchmark/latency_throughput/bench_one.py — view file @ dc1b8bcf

@@ -92,4 +92,4 @@ if __name__ == "__main__":
     print(ret)
 
     speed = args.batch_size * max_new_tokens / latency
-    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
\ No newline at end of file
+    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
benchmark/latency_throughput/bench_serving.py — view file @ dc1b8bcf

@@ -307,8 +307,9 @@ def main(args: argparse.Namespace):
     avg_per_output_token_latency = np.mean(
         [latency / output_len for _, output_len, latency in REQUEST_LATENCY]
     )
-    decoding_throughput = np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
+    decoding_throughput = (
+        np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
+    )
     print(f"Total time: {benchmark_time:.2f} s")
     print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
benchmark/line_retrieval/gen_data.py — view file @ dc1b8bcf

@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
     )
     for i in redirect_indices:
         target_idx = np.random.choice(min(i * 2 + 100, num_lines))
-        lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
+        lines[i] = (
+            f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
+        )
         redirects[i] = target_idx
 
     # Build links and find sources
benchmark/mmlu/bench_sglang.py — view file @ dc1b8bcf

@@ -80,10 +80,12 @@ def main(args):
     for i in range(test_df.shape[0]):
         prompt_end = format_example(test_df, i, include_answer=False)
-        arguments.append({
-            "examples": few_shot_examples,
-            "question": prompt_end,
-        })
+        arguments.append(
+            {
+                "examples": few_shot_examples,
+                "question": prompt_end,
+            }
+        )
         label = test_df.iloc[i, test_df.shape[1] - 1]
         labels.append(label)

@@ -134,7 +136,9 @@ def main(args):
     pt = 0
     for subject, num_qs in zip(subjects[: args.nsub], num_questions):
-        print(f"subject: {subject}, #q: {num_qs}, acc: {np.mean(cors[pt : pt + num_qs]):.3f}")
+        print(
+            f"subject: {subject}, #q: {num_qs}, acc: {np.mean(cors[pt : pt + num_qs]):.3f}"
+        )
         pt += num_qs
     assert pt == len(cors)
 
     weighted_acc = np.mean(cors)
python/sglang/bench_latency.py — view file @ dc1b8bcf

@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):
     for i in range(len(prompts)):
         assert len(input_ids[i]) > bench_args.cut_len
 
-        tmp_input_ids = input_ids[i][:bench_args.cut_len]
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params

@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):
 def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.input_ids += input_ids[i][bench_args.cut_len:]
+        req.input_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
-            i, :bench_args.cut_len
+            i, : bench_args.cut_len
         ]
     return reqs

@@ -151,7 +151,8 @@ def extend(reqs, model_runner):
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
-        tree_cache=None)
+        tree_cache=None,
+    )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
     next_token_ids, _ = batch.sample(output.next_token_logits)

@@ -212,7 +213,9 @@ def latency_test(
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+    print(
+        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+    )
 
     # Prepare inputs
     reqs = prepare_synthetic_inputs(bench_args, tokenizer)

@@ -232,7 +235,9 @@ def latency_test(
         prefill_latency = time.time() - tic
         tot_latency += prefill_latency
         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+        rank_print(
+            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+        )
 
         # Decode
         for i in range(output_len):

@@ -243,13 +248,24 @@ def latency_test(
             latency = time.time() - tic
             tot_latency += latency
             throughput = bench_args.batch_size / latency
-            if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+            if i < 5:
+                rank_print(
+                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                )
         avg_decode_latency = (tot_latency - prefill_latency) / output_len
         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
-
-        throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
-        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+        rank_print(
+            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+        )
+
+        throughput = (
+            (bench_args.input_len + bench_args.output_len)
+            * bench_args.batch_size
+            / tot_latency
+        )
+        rank_print(
+            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+        )
 
     # Warm up
     run_once(4)

@@ -298,4 +314,4 @@ if __name__ == "__main__":
         format="%(message)s",
     )
 
-    main(server_args, bench_args)
\ No newline at end of file
+    main(server_args, bench_args)
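The prints reformatted in the bench_latency.py hunks above encode the benchmark's throughput definitions. A small self-contained sketch with hypothetical numbers, only to make those formulas concrete:

# Hypothetical timings, just to put numbers on the formulas in the hunks above.
batch_size, input_len, output_len = 16, 128, 256
prefill_latency = 0.42                 # seconds for the prefill (extend) pass
decode_latencies = [0.021] * output_len

prefill_throughput = input_len * batch_size / prefill_latency           # token/s
tot_latency = prefill_latency + sum(decode_latencies)
avg_decode_latency = (tot_latency - prefill_latency) / output_len
avg_decode_throughput = batch_size / avg_decode_latency                 # token/s
total_throughput = (input_len + output_len) * batch_size / tot_latency  # token/s
print(prefill_throughput, avg_decode_throughput, total_throughput)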
python/sglang/global_config.py — view file @ dc1b8bcf

@@ -39,4 +39,5 @@ class GlobalConfig:
         # This can improve the speed for large batch sizes during prefill.
         self.layer_sync_threshold = 8192
 
+
 global_config = GlobalConfig()
python/sglang/lang/ir.py — view file @ dc1b8bcf

@@ -185,8 +185,10 @@ class SglFunction:
             batch_kwargs = [
                 {self.arg_names[i]: v for i, v in enumerate(arg_values)}
                 for arg_values in batch_kwargs
-                if isinstance(arg_values, (list, tuple))
-                and len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+                if isinstance(arg_values, (list, tuple))
+                and len(self.arg_names) - len(self.arg_defaults)
+                <= len(arg_values)
+                <= len(self.arg_names)
             ]
             # Ensure to raise an exception if the number of arguments mismatch
             if len(batch_kwargs) != num_programs:
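The rewrapped condition in the ir.py hunk above accepts a positional-argument tuple only when its length falls between the number of required arguments and the total number of parameters. A sketch with hypothetical arg_names and arg_defaults, only to illustrate the bound being checked:

# Hypothetical values: a tuple is accepted when
# len(required args) <= len(arg_values) <= len(all args).
arg_names = ["question", "examples", "temperature"]
arg_defaults = [0.0]                      # only `temperature` has a default
lo = len(arg_names) - len(arg_defaults)   # 2 required arguments
hi = len(arg_names)                       # at most 3 positional values

for arg_values in [("q",), ("q", "ex"), ("q", "ex", 0.7), ("q", "ex", 0.7, "extra")]:
    ok = isinstance(arg_values, (list, tuple)) and lo <= len(arg_values) <= hi
    print(len(arg_values), ok)            # 1 False, 2 True, 3 True, 4 False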
python/sglang/srt/constrained/__init__.py — view file @ dc1b8bcf

@@ -5,13 +5,14 @@ from pydantic import BaseModel
 try:
     from outlines.caching import cache as disk_cache
-    from outlines.fsm.guide import RegexGuide
     from outlines.caching import disable_cache
+    from outlines.fsm.guide import RegexGuide
     from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
     from outlines.models.transformers import TransformerTokenizer
 except ImportError as e:
-    print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+    print(
+        f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
+    )
     raise
 
 try:
python/sglang/srt/hf_transformers_utils.py — view file @ dc1b8bcf

@@ -264,7 +264,9 @@ class TiktokenTokenizer:
         return self.tokenizer.decode_batch(batch)
 
     def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
         return self.encode(ret) if tokenize else ret

@@ -297,5 +299,7 @@ class SentencePieceTokenizer:
         return self.tokenizer.decode(batch)
 
     def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
-        return self.encode(ret) if tokenize else ret
\ No newline at end of file
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
python/sglang/srt/layers/fused_moe.py — view file @ dc1b8bcf

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger

@@ -108,12 +107,16 @@ def fused_moe_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
 
     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
     if use_fp8:
         a_scale = tl.load(a_scale_ptr)

@@ -129,13 +132,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(a_ptrs,
-                    mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                    other=0.0)
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)

@@ -146,9 +148,7 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
     if use_fp8:

@@ -158,15 +158,14 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,
-                         num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def moe_align_block_size(
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.

@@ -205,32 +204,38 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty((max_num_m_blocks, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
-    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            A_scale: Optional[torch.Tensor],
-                            B_scale: Optional[torch.Tensor],
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                            sorted_token_ids: torch.Tensor,
-                            expert_ids: torch.Tensor,
-                            num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any], compute_type: tl.dtype,
-                            use_fp8: bool) -> None:
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8: bool,
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

@@ -241,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
 
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META['BLOCK_SIZE_M']) *
-                         triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )
 
     fused_moe_kernel[grid](
         A,

@@ -280,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 @functools.lru_cache
-def get_moe_configs(E: int, N: int,
-                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.

@@ -296,11 +302,11 @@ def get_moe_configs(E: int, N: int,
     json_file_name = get_config_file_name(E, N, dtype)
 
     config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.",
-                        config_file_path)
+            logger.info("Using configuration from %s for MoE layer.", config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}

@@ -319,35 +325,35 @@ def get_default_config(
 ) -> Dict[str, int]:
     if dtype == "float8":
         config = {
-            'BLOCK_SIZE_M': 128,
-            'BLOCK_SIZE_N': 256,
-            'BLOCK_SIZE_K': 128,
-            'GROUP_SIZE_M': 32,
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 4
+            "num_stages": 4,
         }
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 64,
-                'BLOCK_SIZE_N': 128,
-                'BLOCK_SIZE_K': 128,
-                'GROUP_SIZE_M': 1,
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 4
+                "num_stages": 4,
             }
     else:
         config = {
-            'BLOCK_SIZE_M': 64,
-            'BLOCK_SIZE_N': 64,
-            'BLOCK_SIZE_K': 32,
-            'GROUP_SIZE_M': 8
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
         }
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 16,
-                'BLOCK_SIZE_N': 32,
-                'BLOCK_SIZE_K': 64,
-                'GROUP_SIZE_M': 1
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
             }
     return config

@@ -358,23 +364,17 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
 
-    topk_weights = torch.empty(M,
-                               topk,
-                               dtype=torch.float32,
-                               device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk,
-                           dtype=torch.int32,
-                           device=hidden_states.device)
-    token_expert_indicies = torch.empty(M,
-                                        topk,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
     ops.topk_softmax(
         topk_weights,
         topk_ids,

@@ -388,27 +388,27 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
-def fused_experts(hidden_states: torch.Tensor,
-                  w1: torch.Tensor,
-                  w2: torch.Tensor,
-                  topk_weights: torch.Tensor,
-                  topk_ids: torch.Tensor,
-                  inplace: bool = False,
-                  override_config: Optional[Dict[str, Any]] = None,
-                  use_fp8: bool = False,
-                  w1_scale: Optional[torch.Tensor] = None,
-                  w2_scale: Optional[torch.Tensor] = None,
-                  a1_scale: Optional[torch.Tensor] = None,
-                  a2_scale: Optional[torch.Tensor] = None):
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
 
     M, _ = hidden_states.shape
     E, N, _ = w1.shape

@@ -417,8 +417,7 @@ def fused_experts(hidden_states: torch.Tensor,
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
 
         if configs:
             # If an optimal configuration map has been found, look up the

@@ -426,65 +425,76 @@ def fused_experts(hidden_states: torch.Tensor,
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = get_default_config(M, E, N, w1.shape[2],
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None)
-
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+            config = get_default_config(
+                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
+            )
+
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
-
-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        a1_scale,
+        w1_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
 
     ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        a2_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
 
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)

@@ -532,25 +542,28 @@ def fused_moe(
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
     if hasattr(ops, "topk_softmax"):
-        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                            renormalize)
+        topk_weights, topk_ids = fused_topk(
+            hidden_states, gating_output, topk, renormalize
+        )
     else:
-        topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output,
-                                                   topk, renormalize)
-
-    return fused_experts(hidden_states,
-                         w1,
-                         w2,
-                         topk_weights,
-                         topk_ids,
-                         inplace=inplace,
-                         override_config=override_config,
-                         use_fp8=use_fp8,
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         a1_scale=a1_scale,
-                         a2_scale=a2_scale)
+        topk_weights, topk_ids = fused_topk_v0_4_3(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+    return fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        override_config=override_config,
+        use_fp8=use_fp8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )

@@ -560,6 +573,7 @@ def fused_topk_v0_4_3(
     renormalize: bool,
 ):
     import vllm._moe_C as moe_kernels
+
     M, _ = hidden_states.shape
 
     topk_weights = torch.empty(

@@ -579,4 +593,4 @@ def fused_topk_v0_4_3(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    return topk_weights, topk_ids
\ No newline at end of file
+    return topk_weights, topk_ids
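The moe_align_block_size docstring in the fused_moe.py diff above describes padding each expert's token list up to a multiple of the kernel block size; the buffer it allocates is sized by the worst case topk_ids.numel() + num_experts * (block_size - 1). A tiny sketch with made-up routing, only to make that bound concrete:

# Worst case used to size `sorted_ids` above: every expert may need up to
# (block_size - 1) padding slots.
import torch

topk_ids = torch.tensor([[0, 3], [1, 3], [0, 2]])   # 3 tokens, top-2 routing (hypothetical)
num_experts, block_size = 4, 16
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
print(max_num_tokens_padded)   # 6 + 4 * 15 = 66 candidate slots before truncation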
python/sglang/srt/layers/radix_attention.py — view file @ dc1b8bcf

 """Radix attention."""
+
 import numpy as np
 import torch
 from torch import nn

@@ -11,8 +12,13 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
-        layer_id: int, logit_cap: int = -1
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads

@@ -112,6 +118,7 @@ class RadixAttention(nn.Module):
             )
             from flashinfer.cascade import merge_state
+
             o, _ = merge_state(o1, s1, o2, s2)
 
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
python/sglang/srt/managers/controller/manager_single.py — view file @ dc1b8bcf

@@ -99,4 +99,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        kill_parent_process()
\ No newline at end of file
+        kill_parent_process()
python/sglang/srt/managers/controller/model_runner.py — view file @ dc1b8bcf

@@ -127,7 +127,7 @@ class InputMetadata:
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
-                1
+                1,
             )
         else:
             self.flashinfer_decode_wrapper.end_forward()

@@ -140,7 +140,7 @@ class InputMetadata:
                 head_dim,
                 1,
                 pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
             )
 
     def init_extend_args(self):

@@ -228,7 +228,7 @@ class InputMetadata:
             ret.init_flashinfer_args(
                 model_runner.model_config.num_attention_heads // tp_size,
                 model_runner.model_config.get_num_kv_heads(tp_size),
-                model_runner.model_config.head_dim
+                model_runner.model_config.head_dim,
             )
 
         return ret

@@ -269,7 +269,7 @@ class ModelRunner:
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(

@@ -341,7 +341,13 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )

@@ -384,15 +390,16 @@ class ModelRunner:
     def init_flash_infer(self):
         if not global_server_args_dict.get("disable_flashinfer", False):
             from flashinfer import (
-                BatchPrefillWithRaggedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchPrefillWithRaggedKVCacheWrapper,
             )
             from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
             if not _grouped_size_compiled_for_decode_kernels(
                 self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size)):
+                self.model_config.get_num_kv_heads(self.tp_size),
+            ):
                 use_tensor_cores = True
             else:
                 use_tensor_cores = False

@@ -400,8 +407,8 @@ class ModelRunner:
             workspace_buffers = torch.empty(
                 3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
             )
-            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                workspace_buffers[0], "NHD"
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
             )
             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffers[1], "NHD"

@@ -410,7 +417,9 @@ class ModelRunner:
                 workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
             )
         else:
-            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_prefill_wrapper_ragged = (
+                self.flashinfer_prefill_wrapper_paged
+            ) = None
             self.flashinfer_decode_wrapper = None
 
     @torch.inference_mode()
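The cell_size expression rewrapped in the model_runner.py diff above is the per-token KV-cache footprint: KV heads x head dim x layers x 2 (K and V) x element size. A worked example with hypothetical 7B-class numbers, only to put units on it:

# Hypothetical model shape; the formula itself is the one from the hunk above.
head_num, head_dim, num_hidden_layers = 32, 128, 32
element_size = 2                      # bytes per value for fp16/bf16
cell_size = head_num * head_dim * num_hidden_layers * 2 * element_size
print(cell_size)                      # 524288 bytes = 0.5 MiB of KV cache per token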
python/sglang/srt/managers/controller/tp_worker.py — view file @ dc1b8bcf

@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
+    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     start_rpyc_service_process,
-    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback

@@ -368,9 +368,11 @@ class ModelTpServer:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and (req.extend_input_len + new_batch_input_tokens
-                     <= self.max_prefill_tokens or len(can_run_list) == 0)
+                and (
+                    req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0
+                )
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta

@@ -452,7 +454,9 @@ class ModelTpServer:
                     next_token_ids,
                 ].tolist()
                 output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
-                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
+                output.normalized_prompt_logprobs = (
+                    output.normalized_prompt_logprobs.tolist()
+                )
 
             next_token_ids = next_token_ids.tolist()
         else:

@@ -582,7 +586,9 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
                 if req.top_logprobs_num > 0:
                     req.decode_top_logprobs.append(output.decode_top_logprobs[i])

@@ -759,16 +765,27 @@ class ModelTpClient:
         with ThreadPoolExecutor(self.tp_size) as executor:
             # Launch model processes
             if server_args.nnodes == 1:
-                self.procs = list(executor.map(
-                    lambda args: start_rpyc_service_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                ))
+                self.procs = list(
+                    executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [
+                            (ModelTpService, p)
+                            for p in model_port_args.model_tp_ports
+                        ],
+                    )
+                )
                 addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
             else:
-                addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
-
-            self.model_services = list(executor.map(
-                lambda args: connect_rpyc_service(*args), addrs))
+                addrs = [
+                    (ip, port)
+                    for ip, port in zip(
+                        model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                    )
+                ]
+
+            self.model_services = list(
+                executor.map(lambda args: connect_rpyc_service(*args), addrs)
+            )
 
             # Init model
             def init_model(i):
python/sglang/srt/managers/tokenizer_manager.py — view file @ dc1b8bcf

@@ -334,15 +334,15 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
 
         if top_logprobs_num > 0:
-            ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+            ret["meta_info"]["prefill_top_logprobs"] = (
+                self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
             )
-            ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+            ret["meta_info"]["decode_top_logprobs"] = (
+                self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
             )
 
         return ret
python/sglang/srt/models/gemma2.py — view file @ dc1b8bcf

@@ -5,19 +5,23 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import Gemma2Config
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+
+# FIXME: temporary solution, remove after next vllm release
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import GeluAndMul
+
 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata

@@ -26,8 +30,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
 
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.

@@ -76,13 +78,19 @@ class GemmaRMSNorm(CustomOp):
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+
+
 class GemmaRotaryEmbedding(RotaryEmbedding):
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
-            self.rotary_dim))
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
+                / self.rotary_dim
+            )
+        )
         return inv_freq

@@ -98,18 +106,17 @@ class Gemma2MLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+        )
         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
                 "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                 "function. Please set `hidden_act` and `hidden_activation` to "
-                "`gelu_pytorch_tanh`.")
+                "`gelu_pytorch_tanh`."
+            )
         self.act_fn = GeluAndMul(approximate="tanh")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -121,17 +128,19 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
-    def __init__(self,
-                 layer_idx: int,
-                 config: Gemma2Config,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 head_dim: int,
-                 max_position_embeddings: int,
-                 rope_theta: float,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+    def __init__(
+        self,
+        layer_idx: int,
+        config: Gemma2Config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        max_position_embeddings: int,
+        rope_theta: float,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config

@@ -183,15 +192,16 @@ class Gemma2Attention(nn.Module):
         # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
         # odd layer, vLLM currently ignores it and uses global attention for
         # all layers.
-        use_sliding_window = (layer_idx % 2 == 1
-                              and config.sliding_window is not None)
+        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
         del use_sliding_window  # Unused.
-        self.attn = RadixAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   layer_id=layer_idx,
-                                   logit_cap=self.config.attn_logit_softcapping)
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_idx,
+            logit_cap=self.config.attn_logit_softcapping,
+        )
 
     def forward(
         self,

@@ -238,14 +248,16 @@ class Gemma2DecoderLayer(nn.Module):
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
-        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                      eps=config.rms_norm_eps)
-        self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                       eps=config.rms_norm_eps)
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.pre_feedforward_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     def forward(
         self,

@@ -258,8 +270,7 @@ class Gemma2DecoderLayer(nn.Module):
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,

@@ -268,7 +279,8 @@ class Gemma2DecoderLayer(nn.Module):
         hidden_states = self.post_attention_layernorm(hidden_states)
 
         hidden_states, residual = self.pre_feedforward_layernorm(
-            hidden_states, residual)
+            hidden_states, residual
+        )
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         return hidden_states, residual

@@ -289,10 +301,12 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         # Normalize the embedding by sqrt(hidden_size)

@@ -392,7 +406,7 @@ class Gemma2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            for (param_name, shard_name, shard_id) in stacked_params_mapping:
+            for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
                 name = name.replace(shard_name, param_name)

@@ -412,8 +426,7 @@ class Gemma2ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)

@@ -421,7 +434,8 @@ class Gemma2ForCausalLM(nn.Module):
         if unloaded_params:
             raise RuntimeError(
                 "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}")
+                f"{unloaded_params}"
+            )
 
 
-EntryClass = Gemma2ForCausalLM
\ No newline at end of file
+EntryClass = Gemma2ForCausalLM
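The GemmaRotaryEmbedding hunk above rewraps the inverse-frequency computation inv_freq = 1 / base^(2i / rotary_dim). The same formula in isolation, with hypothetical rotary_dim and base values chosen only for illustration:

# Standalone version of the rewrapped expression above.
import torch

rotary_dim, base = 8, 10000.0
inv_freq = 1.0 / (
    base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
)
print(inv_freq)   # values: 1.0, 0.1, 0.01, 0.001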
python/sglang/srt/models/llama_classification.py — view file @ dc1b8bcf

@@ -5,14 +5,12 @@ import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-)
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.controller.model_runner import InputMetadata
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.model_runner import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel

@@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module):
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
 
-        self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
+        self.classification_head = nn.Linear(
+            config.hidden_size, config.classification_out_size
+        )
         self.eos_token_id = config.eos_token_id
 
     def forward(

@@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module):
         if scores.shape[0] != input_metadata.batch_size:
             print("Warning: the EOS tokens are missing in some sentences.")
-            scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
+            scores = torch.ones(
+                (input_metadata.batch_size, self.config.classification_out_size)
+            ).to(input_ids.device)
 
         return LogitProcessorOutput(
             next_token_logits=scores,

@@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-EntryClass = LlamaForClassification
\ No newline at end of file
+
+EntryClass = LlamaForClassification
python/sglang/srt/server.py — view file @ dc1b8bcf

@@ -51,13 +51,12 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    send_addrs_to_rank_0,
     receive_addrs,
+    send_addrs_to_rank_0,
     start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback
 
-
 logger = logging.getLogger(__name__)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -152,9 +151,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args):
     if server_args.disable_disk_cache:
         disable_cache()
     if not server_args.disable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
-                           "reinstall the latest version by following the instructions "
-                           "at https://docs.flashinfer.ai/installation.html.")
+        assert_pkg_version(
+            "flashinfer",
+            "0.0.8",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)

@@ -176,7 +179,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args):
             ModelPortArgs(
                 nccl_port=ports[3 + i * (tp_size_local + 1)],
                 model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
+                model_tp_ports=ports[
+                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
+                ],
             )
         )
     port_args = PortArgs(

@@ -194,9 +199,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args):
         else:
             receive_addrs(model_port_args[0], server_args)
         for i in range(tp_size_local):
-            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+            start_rpyc_service_process(
+                ModelTpService, model_port_args[0].model_tp_ports[i]
+            )
         if server_args.node_rank != 0:
-            logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
+            logger.info(
+                f"[node_rank={server_args.node_rank}]: Listen for connections..."
+            )
             while True:
                 pass
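The ModelPortArgs hunk in the server.py diff above slices one flat list of allocated ports into a per-replica NCCL port plus the tensor-parallel ports. A sketch of that index arithmetic with hypothetical sizes (the first three ports stand in for the other server components in this sketch):

# Hypothetical layout: ports[0:3] reserved, then each data-parallel replica gets
# 1 nccl port followed by tp_size_local tp ports.
tp_size_local, dp_size = 2, 2
ports = list(range(30000, 30000 + 3 + dp_size * (tp_size_local + 1)))

for i in range(dp_size):
    nccl_port = ports[3 + i * (tp_size_local + 1)]
    model_tp_ports = ports[
        3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
    ]
    print(i, nccl_port, model_tp_ports)
# 0 30003 [30004, 30005]
# 1 30006 [30007, 30008]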
python/sglang/srt/server_args.py — view file @ dc1b8bcf

@@ -137,17 +137,16 @@ class ServerArgs:
             "--dtype",
             type=str,
             default=ServerArgs.dtype,
-            choices=["auto", "half", "float16", "bfloat16",
-                     "float", "float32"],
-            help='Data type for model weights and activations.\n\n'
+            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
+            help="Data type for model weights and activations.\n\n"
             '* "auto" will use FP16 precision for FP32 and FP16 models, and '
-            'BF16 precision for BF16 models.\n'
+            "BF16 precision for BF16 models.\n"
             '* "half" for FP16. Recommended for AWQ quantization.\n'
             '* "float16" is the same as "half".\n'
             '* "bfloat16" for a balance between precision and range.\n'
             '* "float" is shorthand for FP32 precision.\n'
-            '* "float32" for FP32 precision.')
+            '* "float32" for FP32 precision.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",

@@ -271,19 +270,12 @@ class ServerArgs:
         parser.add_argument(
             "--nccl-init-addr",
             type=str,
-            help="The nccl init address of multi-node server."
+            help="The nccl init address of multi-node server.",
         )
         parser.add_argument(
-            "--nnodes",
-            type=int,
-            default=1,
-            help="The number of nodes."
-        )
-        parser.add_argument(
-            "--node-rank",
-            type=int,
-            help="The node rank."
+            "--nnodes", type=int, default=1, help="The number of nodes."
         )
+        parser.add_argument("--node-rank", type=int, help="The node rank.")
 
         # Optimization/debug options
         parser.add_argument(
python/sglang/srt/utils.py — view file @ dc1b8bcf

@@ -432,13 +432,12 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
         if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
             raise Exception(
                 f"{pkg} is installed with version {installed_version}, which "
-                f"is less than the minimum required version {min_version}. " +
-                message
+                f"is less than the minimum required version {min_version}. " + message
             )
     except PackageNotFoundError:
         raise Exception(
-            f"{pkg} with minimum required version {min_version} is not installed. " +
-            message
+            f"{pkg} with minimum required version {min_version} is not installed. "
+            + message
         )

@@ -474,24 +473,40 @@ def monkey_patch_vllm_dummy_weight_loader():
     """
     from vllm.model_executor.model_loader.loader import (
-        ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig,
-        ParallelConfig, SchedulerConfig, CacheConfig, nn,
-        set_default_torch_dtype, _initialize_model, initialize_dummy_weights,
-        DummyModelLoader)
+        CacheConfig,
+        DeviceConfig,
+        DummyModelLoader,
+        LoRAConfig,
+        ModelConfig,
+        ParallelConfig,
+        SchedulerConfig,
+        VisionLanguageConfig,
+        _initialize_model,
+        initialize_dummy_weights,
+        nn,
+        set_default_torch_dtype,
+    )
 
-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   vision_language_config: Optional[VisionLanguageConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+    ) -> nn.Module:
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, vision_language_config,
-                                          cache_config)
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                    lora_config,
+                    vision_language_config,
+                    cache_config,
+                )
 
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)

@@ -541,7 +556,7 @@ def get_ip_address(ifname):
         ip_address = fcntl.ioctl(
             s.fileno(),
             0x8915,  # SIOCGIFADDR
-            struct.pack('256s', bytes(ifname[:15], 'utf-8'))
+            struct.pack("256s", bytes(ifname[:15], "utf-8")),
         )[20:24]
     return socket.inet_ntoa(ip_address)

@@ -550,44 +565,66 @@ def send_addrs_to_rank_0(model_port_args, server_args):
     assert server_args.node_rank != 0 and server_args.dp_size == 1
     import torch.distributed as dist
 
-    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
     ip_addr = get_ip_address(ifname)
 
     num_tp_ports = server_args.tp_size // server_args.nnodes
     model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
     ip_addr = [int(x) for x in ip_addr.split(".")]
-    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+    addrs_tensor = torch.tensor(
+        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
+    )
 
     init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
     dist.send(addrs_tensor, dst=0)
-    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+    print(
+        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
+    )
 
     dist.barrier()
     dist.destroy_process_group()
 
 
 def receive_addrs(model_port_args, server_args):
     assert server_args.node_rank == 0 and server_args.dp_size == 1
     import torch.distributed as dist
 
-    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
     ip_addr = get_ip_address(ifname)
 
     num_tp_ports = server_args.tp_size // server_args.nnodes
     model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
 
     init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
     for src_rank in range(1, server_args.nnodes):
         tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
         dist.recv(tensor, src=src_rank)
         ip = ".".join([str(x) for x in tensor[:4].tolist()])
         ports = tensor[4:].tolist()
-        model_port_args.model_tp_ips[num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
-        model_port_args.model_tp_ports[num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)] = ports
+        model_port_args.model_tp_ips[
+            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
+        ] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[
+            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
+        ] = ports
         print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
 
     dist.barrier()
     dist.destroy_process_group()
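send_addrs_to_rank_0 and receive_addrs in the utils.py diff above exchange addresses by packing the four IPv4 octets plus the tensor-parallel ports into a single int tensor. A round-trip sketch of just that encoding, with a hypothetical address and without starting a gloo process group:

# Encode then decode the same layout used by the two functions above.
import torch

ip_addr = "10.0.0.7"                    # hypothetical node address
model_tp_ports = [30010, 30011]         # hypothetical tp ports

packed = torch.tensor(
    [int(x) for x in ip_addr.split(".")] + model_tp_ports, dtype=torch.int
)
ip = ".".join(str(x) for x in packed[:4].tolist())
ports = packed[4:].tolist()
print(ip, ports)   # 10.0.0.7 [30010, 30011]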