change / sglang / Commits / dc1b8bcf

Commit dc1b8bcf (unverified), authored Jul 05, 2024 by Ying Sheng, committed via GitHub on Jul 05, 2024

    Format (#593)

Parent: 5a57b8ad

Showing 20 changed files with 484 additions and 353 deletions (+484, -353)
benchmark/latency_throughput/bench_one.py                 +1    -1
benchmark/latency_throughput/bench_serving.py             +3    -2
benchmark/line_retrieval/gen_data.py                      +3    -3
benchmark/mmlu/bench_sglang.py                            +9    -5
python/sglang/bench_latency.py                            +28   -12
python/sglang/global_config.py                            +1    -0
python/sglang/lang/ir.py                                  +4    -2
python/sglang/srt/constrained/__init__.py                 +3    -2
python/sglang/srt/hf_transformers_utils.py                +7    -3
python/sglang/srt/layers/fused_moe.py                     +181  -167
python/sglang/srt/layers/radix_attention.py               +9    -2
python/sglang/srt/managers/controller/manager_single.py   +1    -1
python/sglang/srt/managers/controller/model_runner.py     +20   -11
python/sglang/srt/managers/controller/tp_worker.py        +31   -14
python/sglang/srt/managers/tokenizer_manager.py           +8    -8
python/sglang/srt/models/gemma2.py                        +74   -60
python/sglang/srt/models/llama_classification.py          +10   -7
python/sglang/srt/server.py                               +17   -8
python/sglang/srt/server_args.py                          +8    -16
python/sglang/srt/utils.py                                +66   -29
benchmark/latency_throughput/bench_one.py (View file @ dc1b8bcf)

@@ -92,4 +92,4 @@ if __name__ == "__main__":

    print(ret)
    speed = args.batch_size * max_new_tokens / latency
    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
\ No newline at end of file
benchmark/latency_throughput/bench_serving.py (View file @ dc1b8bcf)

@@ -307,8 +307,9 @@ def main(args: argparse.Namespace):

    avg_per_output_token_latency = np.mean(
        [latency / output_len for _, output_len, latency in REQUEST_LATENCY]
    )
    decoding_throughput = (
        np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
    )
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
benchmark/line_retrieval/gen_data.py (View file @ dc1b8bcf)

@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):

    )
    for i in redirect_indices:
        target_idx = np.random.choice(min(i * 2 + 100, num_lines))
        lines[i] = (
            f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
        )
        redirects[i] = target_idx

    # Build links and find sources
benchmark/mmlu/bench_sglang.py (View file @ dc1b8bcf)

@@ -80,10 +80,12 @@ def main(args):

    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, include_answer=False)
        arguments.append(
            {
                "examples": few_shot_examples,
                "question": prompt_end,
            }
        )
        label = test_df.iloc[i, test_df.shape[1] - 1]
        labels.append(label)

@@ -134,7 +136,9 @@ def main(args):

    pt = 0
    for subject, num_qs in zip(subjects[: args.nsub], num_questions):
        print(
            f"subject: {subject}, #q: {num_qs}, acc: {np.mean(cors[pt : pt + num_qs]):.3f}"
        )
        pt += num_qs
    assert pt == len(cors)

    weighted_acc = np.mean(cors)
python/sglang/bench_latency.py (View file @ dc1b8bcf)

@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):

    for i in range(len(prompts)):
        assert len(input_ids[i]) > bench_args.cut_len
        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
        req.prefix_indices = []
        req.sampling_params = sampling_params

@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):

def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
    for i in range(len(reqs)):
        req = reqs[i]
        req.input_ids += input_ids[i][bench_args.cut_len :]
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
    return reqs

@@ -151,7 +151,8 @@ def extend(reqs, model_runner):

        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None,
    )
    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
    output = model_runner.forward(batch, ForwardMode.EXTEND)
    next_token_ids, _ = batch.sample(output.next_token_logits)

@@ -212,7 +213,9 @@ def latency_test(

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)
    print(
        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
    )

    # Prepare inputs
    reqs = prepare_synthetic_inputs(bench_args, tokenizer)

@@ -232,7 +235,9 @@ def latency_test(

        prefill_latency = time.time() - tic
        tot_latency += prefill_latency
        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
        rank_print(
            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
        )

        # Decode
        for i in range(output_len):

@@ -243,13 +248,24 @@ def latency_test(

            latency = time.time() - tic
            tot_latency += latency
            throughput = bench_args.batch_size / latency
            if i < 5:
                rank_print(
                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
                )
        avg_decode_latency = (tot_latency - prefill_latency) / output_len
        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
        rank_print(
            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
        )

        throughput = (
            (bench_args.input_len + bench_args.output_len)
            * bench_args.batch_size
            / tot_latency
        )
        rank_print(
            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
        )

    # Warm up
    run_once(4)

@@ -298,4 +314,4 @@ if __name__ == "__main__":

        format="%(message)s",
    )

    main(server_args, bench_args)
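The throughput figures printed by latency_test above come down to plain token counting. The short sketch below re-derives them with made-up values; every number here is hypothetical and chosen only to show the formulas, not measured output.

    # Hypothetical numbers; they only illustrate the formulas used in latency_test().
    batch_size, input_len, output_len = 8, 512, 128
    prefill_latency, tot_latency = 0.20, 1.80  # seconds

    prefill_throughput = input_len * batch_size / prefill_latency           # prefill token/s
    avg_decode_latency = (tot_latency - prefill_latency) / output_len       # seconds per decode step
    avg_decode_throughput = batch_size / avg_decode_latency                 # decode token/s
    total_throughput = (input_len + output_len) * batch_size / tot_latency  # overall token/s
    print(f"{prefill_throughput:9.2f} {avg_decode_throughput:9.2f} {total_throughput:9.2f}")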
python/sglang/global_config.py (View file @ dc1b8bcf)

@@ -39,4 +39,5 @@ class GlobalConfig:

        # This can improve the speed for large batch sizes during prefill.
        self.layer_sync_threshold = 8192


global_config = GlobalConfig()
python/sglang/lang/ir.py (View file @ dc1b8bcf)

@@ -185,8 +185,10 @@ class SglFunction:

        batch_kwargs = [
            {self.arg_names[i]: v for i, v in enumerate(arg_values)}
            for arg_values in batch_kwargs
            if isinstance(arg_values, (list, tuple))
            and len(self.arg_names) - len(self.arg_defaults)
            <= len(arg_values)
            <= len(self.arg_names)
        ]
        # Ensure to raise an exception if the number of arguments mismatch
        if len(batch_kwargs) != num_programs:
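The comprehension in this hunk converts positional argument tuples into keyword dictionaries and drops entries with the wrong arity. A small, self-contained illustration of the same pattern follows; arg_names, arg_defaults, and batch_args are hypothetical stand-ins for this example, not sglang identifiers.

    # Hypothetical stand-ins; this only mirrors the mapping/filtering logic above.
    arg_names = ["question", "examples"]
    arg_defaults = ["(none)"]  # only "examples" has a default value

    batch_args = [("What is 2+2?",), ("Capital of France?", "Q: ... A: ...")]

    batch_kwargs = [
        {arg_names[i]: v for i, v in enumerate(arg_values)}
        for arg_values in batch_args
        if isinstance(arg_values, (list, tuple))
        and len(arg_names) - len(arg_defaults) <= len(arg_values) <= len(arg_names)
    ]
    # [{'question': 'What is 2+2?'},
    #  {'question': 'Capital of France?', 'examples': 'Q: ... A: ...'}]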
python/sglang/srt/constrained/__init__.py (View file @ dc1b8bcf)

@@ -5,13 +5,14 @@ from pydantic import BaseModel

try:
    from outlines.caching import cache as disk_cache
    from outlines.caching import disable_cache
    from outlines.fsm.guide import RegexGuide
    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
    from outlines.models.transformers import TransformerTokenizer
except ImportError as e:
    print(
        f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
    )
    raise

try:
python/sglang/srt/hf_transformers_utils.py (View file @ dc1b8bcf)

@@ -264,7 +264,9 @@ class TiktokenTokenizer:

        return self.tokenizer.decode_batch(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
        ret = self.chat_template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret

@@ -297,5 +299,7 @@ class SentencePieceTokenizer:

        return self.tokenizer.decode(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
        ret = self.chat_template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret
python/sglang/srt/layers/fused_moe.py (View file @ dc1b8bcf)

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple

import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger

@@ -108,12 +107,16 @@ def fused_moe_kernel(

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )

    if use_fp8:
        a_scale = tl.load(a_scale_ptr)

@@ -129,13 +132,12 @@ def fused_moe_kernel(

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)

@@ -146,9 +148,7 @@ def fused_moe_kernel(

        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:

@@ -158,15 +158,14 @@ def fused_moe_kernel(

    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

@@ -205,32 +204,38 @@ def moe_align_block_size(

    by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )
    return sorted_ids, expert_ids, num_tokens_post_pad


def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8: bool,
) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

@@ -241,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,

        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None

    grid = lambda META: (
        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
    )

    fused_moe_kernel[grid](
        A,

@@ -280,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:

@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

@@ -296,11 +302,11 @@ def get_moe_configs(E: int, N: int,

    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info("Using configuration from %s for MoE layer.", config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

@@ -319,35 +325,35 @@ def get_default_config(

) -> Dict[str, int]:
    if dtype == "float8":
        config = {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 32,
            "num_warps": 8,
            "num_stages": 4,
        }
        if M <= E:
            config = {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 1,
                "num_warps": 4,
                "num_stages": 4,
            }
    else:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
        if M <= E:
            config = {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": 32,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 1,
            }
    return config

@@ -358,23 +364,17 @@ def fused_topk(

    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    M, _ = hidden_states.shape

    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    token_expert_indicies = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )
    ops.topk_softmax(
        topk_weights,
        topk_ids,

@@ -388,27 +388,27 @@ def fused_topk(

    return topk_weights, topk_ids


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    M, _ = hidden_states.shape
    E, N, _ = w1.shape

@@ -417,8 +417,7 @@ def fused_experts(hidden_states: torch.Tensor,

        config = override_config
    else:
        # First try to load optimal config from the file
        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)

        if configs:
            # If an optimal configuration map has been found, look up the

@@ -426,65 +425,76 @@ def fused_experts(hidden_states: torch.Tensor,

            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(
                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
            )

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )
    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        intermediate_cache1,
        a1_scale,
        w1_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        False,
        topk_ids.shape[1],
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    invoke_fused_moe_kernel(
        intermediate_cache2,
        w2,
        intermediate_cache3,
        a2_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        True,
        1,
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    if inplace:
        return torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
            out=hidden_states,
        )
    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)


def fused_moe(

@@ -532,25 +542,28 @@ def fused_moe(

    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if hasattr(ops, "topk_softmax"):
        topk_weights, topk_ids = fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
    else:
        topk_weights, topk_ids = fused_topk_v0_4_3(
            hidden_states, gating_output, topk, renormalize
        )

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        override_config=override_config,
        use_fp8=use_fp8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )


def fused_topk_v0_4_3(

@@ -560,6 +573,7 @@ def fused_topk_v0_4_3(

    renormalize: bool,
):
    import vllm._moe_C as moe_kernels

    M, _ = hidden_states.shape

    topk_weights = torch.empty(

@@ -579,4 +593,4 @@ def fused_topk_v0_4_3(

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
\ No newline at end of file
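To put the reformatted signatures in context, here is a rough usage sketch of the top-level entry point. It assumes the upstream signature fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize, ...), which is not fully visible in the hunks above, and tensor shapes inferred from the asserts (hidden size = w1.shape[2], number of experts = w1.shape[0] = gating_output.shape[1]). Treat it as an untested sketch that also requires a CUDA build of vLLM's custom ops.

    # Untested sketch; shapes follow the asserts in fused_experts/fused_moe above.
    import torch
    from sglang.srt.layers.fused_moe import fused_moe

    M, hidden, inter, E, topk = 8, 128, 256, 4, 2  # tokens, hidden, intermediate, experts
    hidden_states = torch.randn(M, hidden, dtype=torch.float16, device="cuda")
    w1 = torch.randn(E, 2 * inter, hidden, dtype=torch.float16, device="cuda")  # gate+up fused
    w2 = torch.randn(E, hidden, inter, dtype=torch.float16, device="cuda")
    gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")

    out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
    print(out.shape)  # expected: torch.Size([8, 128])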
python/sglang/srt/layers/radix_attention.py (View file @ dc1b8bcf)

"""Radix attention."""

import numpy as np
import torch
from torch import nn

@@ -11,8 +12,13 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata

class RadixAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scaling: float,
        num_kv_heads: int,
        layer_id: int,
        logit_cap: int = -1,
    ):
        super().__init__()
        self.tp_q_head_num = num_heads

@@ -112,6 +118,7 @@ class RadixAttention(nn.Module):

        )
        from flashinfer.cascade import merge_state

        o, _ = merge_state(o1, s1, o2, s2)

        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
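For reference, a hedged instantiation sketch of the constructor shown in the hunk above; the numeric values are placeholders chosen for illustration, not taken from any real model config.

    # Placeholder values; mirrors the constructor signature shown in the hunk above.
    from sglang.srt.layers.radix_attention import RadixAttention

    attn = RadixAttention(
        num_heads=32,
        head_dim=128,
        scaling=128**-0.5,
        num_kv_heads=8,
        layer_id=0,
        logit_cap=-1,  # default shown in the diff; presumably leaves logit capping off
    )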
python/sglang/srt/managers/controller/manager_single.py (View file @ dc1b8bcf)

@@ -99,4 +99,4 @@ def start_controller_process(

    except Exception:
        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
    finally:
        kill_parent_process()
\ No newline at end of file
python/sglang/srt/managers/controller/model_runner.py (View file @ dc1b8bcf)

@@ -127,7 +127,7 @@ class InputMetadata:

                num_qo_heads,
                num_kv_heads,
                head_dim,
                1,
            )
        else:
            self.flashinfer_decode_wrapper.end_forward()

@@ -140,7 +140,7 @@ class InputMetadata:

                head_dim,
                1,
                pos_encoding_mode="NONE",
                data_type=self.token_to_kv_pool.kv_data[0].dtype,
            )

    def init_extend_args(self):

@@ -228,7 +228,7 @@ class InputMetadata:

        ret.init_flashinfer_args(
            model_runner.model_config.num_attention_heads // tp_size,
            model_runner.model_config.get_num_kv_heads(tp_size),
            model_runner.model_config.head_dim,
        )

        return ret

@@ -269,7 +269,7 @@ class ModelRunner:

            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=nccl_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        total_gpu_memory = get_available_gpu_memory(

@@ -341,7 +341,13 @@ class ModelRunner:

        )
        head_dim = self.model_config.head_dim
        head_num = self.model_config.get_num_kv_heads(self.tp_size)
        cell_size = (
            head_num
            * head_dim
            * self.model_config.num_hidden_layers
            * 2
            * torch._utils._element_size(self.dtype)
        )
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )

@@ -384,15 +390,16 @@ class ModelRunner:

    def init_flash_infer(self):
        if not global_server_args_dict.get("disable_flashinfer", False):
            from flashinfer import (
                BatchDecodeWithPagedKVCacheWrapper,
                BatchPrefillWithPagedKVCacheWrapper,
                BatchPrefillWithRaggedKVCacheWrapper,
            )
            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

            if not _grouped_size_compiled_for_decode_kernels(
                self.model_config.num_attention_heads // self.tp_size,
                self.model_config.get_num_kv_heads(self.tp_size),
            ):
                use_tensor_cores = True
            else:
                use_tensor_cores = False

@@ -400,8 +407,8 @@ class ModelRunner:

            workspace_buffers = torch.empty(
                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
            )
            self.flashinfer_prefill_wrapper_ragged = (
                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
            )
            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                workspace_buffers[1], "NHD"

@@ -410,7 +417,9 @@ class ModelRunner:

                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
            )
        else:
            self.flashinfer_prefill_wrapper_ragged = (
                self.flashinfer_prefill_wrapper_paged
            ) = None
            self.flashinfer_decode_wrapper = None

    @torch.inference_mode()
python/sglang/srt/managers/controller/tp_worker.py (View file @ dc1b8bcf)

@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (

from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
    connect_rpyc_service,
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
    start_rpyc_service_process,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

@@ -368,9 +368,11 @@ class ModelTpServer:

            if (
                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and (
                    req.extend_input_len + new_batch_input_tokens
                    <= self.max_prefill_tokens
                    or len(can_run_list) == 0
                )
            ):
                delta = self.tree_cache.inc_lock_ref(req.last_node)
                available_size += delta

@@ -452,7 +454,9 @@ class ModelTpServer:

                next_token_ids,
            ].tolist()
            output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
            output.normalized_prompt_logprobs = (
                output.normalized_prompt_logprobs.tolist()
            )

            next_token_ids = next_token_ids.tolist()
        else:

@@ -582,7 +586,9 @@ class ModelTpServer:

            req.check_finished()

            if req.return_logprob:
                req.decode_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])

@@ -759,16 +765,27 @@ class ModelTpClient:

            with ThreadPoolExecutor(self.tp_size) as executor:
                # Launch model processes
                if server_args.nnodes == 1:
                    self.procs = list(
                        executor.map(
                            lambda args: start_rpyc_service_process(*args),
                            [(ModelTpService, p) for p in model_port_args.model_tp_ports],
                        )
                    )
                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
                else:
                    addrs = [
                        (ip, port)
                        for ip, port in zip(
                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
                        )
                    ]

                self.model_services = list(
                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
                )

                # Init model
                def init_model(i):
python/sglang/srt/managers/tokenizer_manager.py (View file @ dc1b8bcf)

@@ -334,15 +334,15 @@ class TokenizerManager:

                    ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
                )
            if top_logprobs_num > 0:
                ret["meta_info"]["prefill_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
                    )
                )
                ret["meta_info"]["decode_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret
python/sglang/srt/models/gemma2.py
View file @
dc1b8bcf
...
@@ -5,19 +5,23 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
...
@@ -5,19 +5,23 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Gemma2Config
from
transformers
import
Gemma2Config
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
# FIXME: temporary solution, remove after next vllm release
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.activation
import
GeluAndMul
# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
)
QKVParallelLinear
,
from
vllm.model_executor.layers.quantization.base_config
import
(
RowParallelLinear
,
QuantizationConfig
)
)
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -26,8 +30,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
...
@@ -26,8 +30,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
from
sglang.srt.managers.controller.model_runner
import
InputMetadata
from
sglang.srt.managers.controller.model_runner
import
InputMetadata
# FIXME: temporary solution, remove after next vllm release
from
vllm.model_executor.custom_op
import
CustomOp
class
GemmaRMSNorm
(
CustomOp
):
class
GemmaRMSNorm
(
CustomOp
):
"""RMS normalization for Gemma.
"""RMS normalization for Gemma.
...
@@ -76,13 +78,19 @@ class GemmaRMSNorm(CustomOp):
...
@@ -76,13 +78,19 @@ class GemmaRMSNorm(CustomOp):
# FIXME: temporary solution, remove after next vllm release
# FIXME: temporary solution, remove after next vllm release
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
class
GemmaRotaryEmbedding
(
RotaryEmbedding
):
class
GemmaRotaryEmbedding
(
RotaryEmbedding
):
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq
=
1.0
/
(
base
**
(
inv_freq
=
1.0
/
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
int64
).
float
()
/
base
self
.
rotary_dim
))
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
int64
).
float
()
/
self
.
rotary_dim
)
)
return
inv_freq
return
inv_freq
...
@@ -98,18 +106,17 @@ class Gemma2MLP(nn.Module):
...
@@ -98,18 +106,17 @@ class Gemma2MLP(nn.Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
bias
=
False
,
)
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
hidden_size
,
)
bias
=
False
,
quant_config
=
quant_config
)
if
not
(
hidden_act
==
hidden_activation
==
"gelu_pytorch_tanh"
):
if
not
(
hidden_act
==
hidden_activation
==
"gelu_pytorch_tanh"
):
raise
ValueError
(
raise
ValueError
(
"Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
"Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
"function. Please set `hidden_act` and `hidden_activation` to "
"function. Please set `hidden_act` and `hidden_activation` to "
"`gelu_pytorch_tanh`."
)
"`gelu_pytorch_tanh`."
)
self
.
act_fn
=
GeluAndMul
(
approximate
=
"tanh"
)
self
.
act_fn
=
GeluAndMul
(
approximate
=
"tanh"
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -121,17 +128,19 @@ class Gemma2MLP(nn.Module):
...
@@ -121,17 +128,19 @@ class Gemma2MLP(nn.Module):
class
Gemma2Attention
(
nn
.
Module
):
class
Gemma2Attention
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
layer_idx
:
int
,
self
,
config
:
Gemma2Config
,
layer_idx
:
int
,
hidden_size
:
int
,
config
:
Gemma2Config
,
num_heads
:
int
,
hidden_size
:
int
,
num_kv_heads
:
int
,
num_heads
:
int
,
head_dim
:
int
,
num_kv_heads
:
int
,
max_position_embeddings
:
int
,
head_dim
:
int
,
rope_theta
:
float
,
max_position_embeddings
:
int
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
rope_theta
:
float
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
config
=
config
...
@@ -183,15 +192,16 @@ class Gemma2Attention(nn.Module):
...
@@ -183,15 +192,16 @@ class Gemma2Attention(nn.Module):
# from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
# all layers.
use_sliding_window
=
(
layer_idx
%
2
==
1
use_sliding_window
=
layer_idx
%
2
==
1
and
config
.
sliding_window
is
not
None
and
config
.
sliding_window
is
not
None
)
del
use_sliding_window
# Unused.
del
use_sliding_window
# Unused.
self
.
attn
=
RadixAttention
(
self
.
num_heads
,
self
.
attn
=
RadixAttention
(
self
.
head_dim
,
self
.
num_heads
,
self
.
scaling
,
self
.
head_dim
,
num_kv_heads
=
self
.
num_kv_heads
,
self
.
scaling
,
layer_id
=
layer_idx
,
num_kv_heads
=
self
.
num_kv_heads
,
logit_cap
=
self
.
config
.
attn_logit_softcapping
)
layer_id
=
layer_idx
,
logit_cap
=
self
.
config
.
attn_logit_softcapping
,
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -238,14 +248,16 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -238,14 +248,16 @@ class Gemma2DecoderLayer(nn.Module):
hidden_activation
=
config
.
hidden_activation
,
hidden_activation
=
config
.
hidden_activation
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
self
.
input_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
GemmaRMSNorm
(
self
.
post_attention_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
eps
=
config
.
rms_norm_eps
)
)
self
.
pre_feedforward_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
self
.
pre_feedforward_layernorm
=
GemmaRMSNorm
(
eps
=
config
.
rms_norm_eps
)
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
self
.
post_feedforward_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
)
eps
=
config
.
rms_norm_eps
)
self
.
post_feedforward_layernorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -258,8 +270,7 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -258,8 +270,7 @@ class Gemma2DecoderLayer(nn.Module):
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -268,7 +279,8 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -268,7 +279,8 @@ class Gemma2DecoderLayer(nn.Module):
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
,
residual
=
self
.
pre_feedforward_layernorm
(
hidden_states
,
residual
=
self
.
pre_feedforward_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
post_feedforward_layernorm
(
hidden_states
)
hidden_states
=
self
.
post_feedforward_layernorm
(
hidden_states
)
return
hidden_states
,
residual
return
hidden_states
,
residual
...
@@ -289,10 +301,12 @@ class Gemma2Model(nn.Module):
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
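The trailing comment in this hunk refers to the usual Gemma behaviour of scaling token embeddings by sqrt(hidden_size) before the first decoder layer. A tiny sketch of that step with made-up sizes (the real model applies it in the part of Gemma2Model that the diff elides):

import torch
from torch import nn

hidden_size, vocab_size = 16, 100                     # toy sizes
embed_tokens = nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([[1, 5, 7]])
hidden_states = embed_tokens(input_ids)
# Normalize the embedding by sqrt(hidden_size).
normalizer = torch.tensor(hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer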
...
@@ -392,7 +406,7 @@ class Gemma2ForCausalLM(nn.Module):
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
...
@@ -412,8 +426,7 @@ class Gemma2ForCausalLM(nn.Module):
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
...
@@ -421,7 +434,8 @@ class Gemma2ForCausalLM(nn.Module):
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}"
            )


EntryClass = Gemma2ForCausalLM
\ No newline at end of file
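The `stacked_params_mapping` used by the weight-loading loop above is defined in a part of the file the diff does not show. In this vLLM-style pattern it is typically a list of (fused parameter name, checkpoint shard name, shard id) triples; the entries below are an illustrative assumption of what such a mapping looks like for a Llama/Gemma-shaped model, not a quote from gemma2.py:

# Hypothetical mapping with the shape the loop expects: (param_name, shard_name, shard_id).
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

# With such a mapping, a checkpoint tensor named "...self_attn.q_proj.weight" is renamed to
# "...self_attn.qkv_proj.weight" and handed to that fused parameter's weight_loader, which
# (in the stacked branch elided above) typically also receives shard_id so the tensor lands
# in the right slice of the fused weight.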
python/sglang/srt/models/llama_classification.py
View file @
dc1b8bcf
...
@@ -5,14 +5,12 @@ import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaModel
...
@@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module):
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config=quant_config)
        self.classification_head = nn.Linear(
            config.hidden_size, config.classification_out_size
        )
        self.eos_token_id = config.eos_token_id

    def forward(
...
@@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module):
        if scores.shape[0] != input_metadata.batch_size:
            print("Warning: the EOS tokens are missing in some sentences.")
            scores = torch.ones(
                (input_metadata.batch_size, self.config.classification_out_size)
            ).to(input_ids.device)

        return LogitProcessorOutput(
            next_token_logits=scores,
...
@@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module):
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


EntryClass = LlamaForClassification
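Context for the warning in the forward hunk above: the classification model pools one hidden state per sequence (at its EOS position, which is why a missing EOS token trips the batch-size check) and feeds the pooled states through the linear classification head. A toy sketch of that final step with made-up sizes:

import torch
from torch import nn

batch_size, hidden_size, num_classes = 3, 16, 4        # made-up sizes
classification_head = nn.Linear(hidden_size, num_classes)

pooled = torch.randn(batch_size, hidden_size)          # stand-in for per-sequence EOS hidden states
scores = classification_head(pooled)                   # shape: (batch_size, num_classes)
assert scores.shape == (batch_size, num_classes)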
python/sglang/srt/server.py
View file @
dc1b8bcf
...
@@ -51,13 +51,12 @@ from sglang.srt.utils import (
    allocate_init_ports,
    assert_pkg_version,
    enable_show_time_cost,
    receive_addrs,
    send_addrs_to_rank_0,
    start_rpyc_service_process,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...
@@ -152,9 +151,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
    if server_args.disable_disk_cache:
        disable_cache()
    if not server_args.disable_flashinfer:
        assert_pkg_version(
            "flashinfer",
            "0.0.8",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )
    if server_args.chat_template:
        # TODO: replace this with huggingface transformers template
        load_chat_template_for_openai_api(server_args.chat_template)
...
@@ -176,7 +179,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
            ModelPortArgs(
                nccl_port=ports[3 + i * (tp_size_local + 1)],
                model_tp_ips=[None] * tp_size_local,
                model_tp_ports=ports[
                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
                ],
            )
        )
    port_args = PortArgs(
...
@@ -194,9 +199,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
    else:
        receive_addrs(model_port_args[0], server_args)
    for i in range(tp_size_local):
        start_rpyc_service_process(
            ModelTpService, model_port_args[0].model_tp_ports[i]
        )
    if server_args.node_rank != 0:
        logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
        while True:
            pass
...
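A quick worked example of the port bookkeeping in the ModelPortArgs hunk above: each data-parallel replica i gets one NCCL port followed by tp_size_local tensor-parallel ports out of the flat ports list (the first three entries appear to be reserved for other server components). With assumed values:

# Assumed values, only to make the index arithmetic concrete.
ports = list(range(30000, 30020))
tp_size_local = 4

for i in range(2):  # two data-parallel replicas
    nccl_port = ports[3 + i * (tp_size_local + 1)]
    model_tp_ports = ports[
        3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
    ]
    print(i, nccl_port, model_tp_ports)
# i=0 -> nccl 30003, tp ports [30004, 30005, 30006, 30007]
# i=1 -> nccl 30008, tp ports [30009, 30010, 30011, 30012]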
python/sglang/srt/server_args.py
View file @
dc1b8bcf
...
@@ -137,17 +137,16 @@ class ServerArgs:
        "--dtype",
        type=str,
        default=ServerArgs.dtype,
        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
        help="Data type for model weights and activations.\n\n"
        '* "auto" will use FP16 precision for FP32 and FP16 models, and '
        "BF16 precision for BF16 models.\n"
        '* "half" for FP16. Recommended for AWQ quantization.\n'
        '* "float16" is the same as "half".\n'
        '* "bfloat16" for a balance between precision and range.\n'
        '* "float" is shorthand for FP32 precision.\n'
        '* "float32" for FP32 precision.',
    )
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
...
@@ -271,19 +270,12 @@ class ServerArgs:
    parser.add_argument(
        "--nccl-init-addr",
        type=str,
        help="The nccl init address of multi-node server.",
    )
    parser.add_argument(
        "--nnodes", type=int, default=1, help="The number of nodes."
    )
    parser.add_argument("--node-rank", type=int, help="The node rank.")

    # Optimization/debug options
    parser.add_argument(
...
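The --dtype hunk above only changes quoting and wrapping, but it is the part of ServerArgs users actually interact with, so here is a stripped-down stand-in showing how the choices list behaves (a toy parser, not the real ServerArgs.add_cli_args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype",
    type=str,
    default="auto",
    choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
    help="Data type for model weights and activations.",
)

print(parser.parse_args(["--dtype", "bfloat16"]).dtype)  # -> bfloat16
print(parser.parse_args([]).dtype)                       # -> auto
# Anything outside the choices list, e.g. --dtype fp8, exits with an argparse error.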
python/sglang/srt/utils.py
View file @
dc1b8bcf
...
@@ -432,13 +432,12 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
            raise Exception(
                f"{pkg} is installed with version {installed_version}, which "
                f"is less than the minimum required version {min_version}. " + message
            )
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        )
...
@@ -474,24 +473,40 @@ def monkey_patch_vllm_dummy_weight_loader():
    """
    from vllm.model_executor.model_loader.loader import (
        CacheConfig,
        DeviceConfig,
        DummyModelLoader,
        LoRAConfig,
        ModelConfig,
        ParallelConfig,
        SchedulerConfig,
        VisionLanguageConfig,
        _initialize_model,
        initialize_dummy_weights,
        nn,
        set_default_torch_dtype,
    )

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(
                    model_config,
                    self.load_config,
                    lora_config,
                    vision_language_config,
                    cache_config,
                )
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
...
@@ -541,7 +556,7 @@ def get_ip_address(ifname):
    ip_address = fcntl.ioctl(
        s.fileno(),
        0x8915,  # SIOCGIFADDR
        struct.pack("256s", bytes(ifname[:15], "utf-8")),
    )[20:24]
    return socket.inet_ntoa(ip_address)
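Usage note for get_ip_address above: it resolves an interface's IPv4 address through the Linux-only SIOCGIFADDR ioctl (0x8915), so it raises OSError for an unknown interface name and is not portable to macOS/Windows. A hedged wrapper sketch with a generic fallback:

import socket

def try_get_ip(ifname: str = "eth0") -> str:
    from sglang.srt.utils import get_ip_address  # the helper shown above

    try:
        return get_ip_address(ifname)
    except OSError:
        # Fallback: ask the kernel which local address routes externally
        # (no packet is actually sent for a UDP connect).
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]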
...
@@ -550,44 +565,66 @@ def send_addrs_to_rank_0(model_port_args, server_args):
    assert server_args.node_rank != 0 and server_args.dp_size == 1
    import torch.distributed as dist

    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    )
    ip_addr = get_ip_address(ifname)
    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
    ip_addr = [int(x) for x in ip_addr.split(".")]
    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
    init_method = f"tcp://{server_args.nccl_init_addr}"
    dist.init_process_group(
        backend="gloo",
        init_method=init_method,
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )
    dist.send(addrs_tensor, dst=0)
    print(
        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
    )
    dist.barrier()
    dist.destroy_process_group()


def receive_addrs(model_port_args, server_args):
    assert server_args.node_rank == 0 and server_args.dp_size == 1
    import torch.distributed as dist

    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    )
    ip_addr = get_ip_address(ifname)
    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
    init_method = f"tcp://{server_args.nccl_init_addr}"
    dist.init_process_group(
        backend="gloo",
        init_method=init_method,
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )
    for src_rank in range(1, server_args.nnodes):
        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
        dist.recv(tensor, src=src_rank)
        ip = ".".join([str(x) for x in tensor[:4].tolist()])
        ports = tensor[4:].tolist()
        model_port_args.model_tp_ips[
            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
        ] = [ip] * num_tp_ports
        model_port_args.model_tp_ports[
            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
        ] = ports
        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
    dist.barrier()
    dist.destroy_process_group()
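The two functions above form a small rendezvous protocol over a temporary gloo process group: every non-zero node packs its interface IP (as four ints) plus its local tensor-parallel ports into one int tensor and sends it to rank 0, which unpacks it back into addresses and ports. A worked example of just the encode/decode step, with a made-up address and ports and no process group involved:

import torch

ip_addr = "10.0.0.2"                 # made-up worker address
model_tp_ports = [30011, 30012]      # made-up tp ports (num_tp_ports = 2)

# Encode, as in send_addrs_to_rank_0: 4 ints for the IPv4 address + one per port.
payload = [int(x) for x in ip_addr.split(".")] + model_tp_ports
tensor = torch.tensor(payload, dtype=torch.int)       # what dist.send() ships to rank 0

# Decode, as in receive_addrs:
ip = ".".join(str(x) for x in tensor[:4].tolist())    # -> "10.0.0.2"
ports = tensor[4:].tolist()                           # -> [30011, 30012]
assert ip == ip_addr and ports == model_tp_ports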