change / sglang · Commits · fb9296f0

Commit fb9296f0 (Unverified)
Authored Jun 12, 2024 by Ying Sheng · Committed by GitHub, Jun 12, 2024

Higher priority for user input of max_prefill_tokens & format (#540)

Parent: 1374334d
Changes: 50

Showing 20 changed files with 227 additions and 172 deletions (+227 −172)
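The hunks visible on this page are almost entirely mechanical formatting changes (import reordering and line reflow); the max_prefill_tokens precedence change named in the commit title does not appear in the files shown here. As a rough, hypothetical sketch of what "higher priority for user input" usually means for a cap like max_prefill_tokens (illustrative names only, not sglang's actual code):

# Hypothetical illustration, not code from this commit.
def resolve_max_prefill_tokens(user_value, derived_default):
    """Give an explicitly supplied user setting priority over a derived default."""
    if user_value is not None:
        return user_value       # the user's CLI/config value wins
    return derived_default      # otherwise fall back to the derived value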
Changed files on this page:

benchmark/gsm8k/bench_other.py                          +1    −1
benchmark/latency_throughput/bench_throughput.py        +24   −10
benchmark/mmlu/bench_other.py                           +1    −1
python/sglang/__init__.py                               +1    −1
python/sglang/backend/litellm.py                        +2    −1
python/sglang/backend/openai.py                         +26   −15
python/sglang/lang/interpreter.py                       +1    −1
python/sglang/lang/ir.py                                +2    −4
python/sglang/launch_server.py                          +1    −1
python/sglang/launch_server_llavavid.py                 +1    −0
python/sglang/srt/constrained/__init__.py               +1    −1
python/sglang/srt/constrained/fsm_cache.py              +1    −0
python/sglang/srt/constrained/jump_forward.py           +2    −1
python/sglang/srt/conversation.py                       +1    −0
python/sglang/srt/hf_transformers_utils.py              +23   −8
python/sglang/srt/layers/fused_moe.py                   +125  −119
python/sglang/srt/layers/logits_processor.py            +1    −0
python/sglang/srt/layers/radix_attention.py             +5    −2
python/sglang/srt/managers/controller/dp_worker.py      +5    −3
python/sglang/srt/managers/controller/infer_batch.py    +3    −3
benchmark/gsm8k/bench_other.py

@@ -65,7 +65,7 @@ def main(args):
    def get_one_answer(i):
        answer = call_generate(
            prompt=few_shot_examples + questions[i],
-            #prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
+            # prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
            temperature=0,
            max_tokens=256,
            stop="Question",
benchmark/latency_throughput/bench_throughput.py

@@ -158,7 +158,9 @@ async def send_request(
    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
            async with session.post(
                api_url, headers=headers, json=pload
            ) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)

@@ -228,19 +230,32 @@ def main(args: argparse.Namespace):
    np.random.seed(args.seed)

    api_url = f"http://{args.host}:{args.port}/generate"
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    if args.dataset:
        input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
    else:
        input_lens = np.random.randint(
            int(args.input_len * args.range_ratio),
            args.input_len + 1,
            size=args.num_prompts,
        )
        output_lens = np.random.randint(
            int(args.output_len * args.range_ratio),
            args.output_len + 1,
            size=args.num_prompts,
        )
        offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
        input_requests = []
        for i in range(args.num_prompts):
            prompt = tokenizer.decode(
                [
                    (offsets[i] + i + j) % tokenizer.vocab_size
                    for j in range(input_lens[i])
                ]
            )
            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

    benchmark_start_time = time.perf_counter()

@@ -287,16 +302,15 @@ if __name__ == "__main__":
    )
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument("--dataset", type=str, help="Path to the dataset.")
    parser.add_argument("--input-len", type=int, default=2048)
    parser.add_argument("--output-len", type=int, default=256)
    parser.add_argument("--range-ratio", type=float, default=1.0)
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="NousResearch/Meta-Llama-3-8B",
        help="Name or path of the tokenizer.",
    )
    parser.add_argument(
        "--best-of",
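For reference, a minimal sketch of the synthetic-prompt path in the hunk above, used when no --dataset is given (any Hugging Face tokenizer works; "gpt2" here is only a stand-in):

# Sketch of how bench_throughput fabricates prompts from random token ids.
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
num_prompts, input_len = 2, 8
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
for i in range(num_prompts):
    # input_len consecutive (wrapped) token ids, decoded back into text
    token_ids = [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_len)]
    print(tokenizer.decode(token_ids))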
benchmark/mmlu/bench_other.py

@@ -170,4 +170,4 @@ if __name__ == "__main__":
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--nsub", type=int, default=60)
    args = add_common_other_args_and_parse(parser)
-    main(args)
\ No newline at end of file
+    main(args)
python/sglang/__init__.py

@@ -24,10 +24,10 @@ from sglang.api import (
# SGL Backends
from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.backend.vertexai import VertexAI
-from sglang.backend.litellm import LiteLLM

# Global Configurations
from sglang.global_config import global_config
python/sglang/backend/litellm.py

@@ -33,7 +33,8 @@ class LiteLLM(BaseBackend):
        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        self.client_params = {
            "api_key": api_key,
python/sglang/backend/openai.py

+import dataclasses
import logging
import time
import warnings
-import dataclasses
from typing import Callable, List, Optional, Union

import numpy as np

@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
    def get_chat_template(self):
        return self.chat_template

    def _prepare_spec_execution(
        self,
        sampling_params: SglSamplingParams,
        num_api_spec_tokens: int,
        spec_var_name: str,
    ):
        if "max_tokens" not in self.spec_kwargs:
            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
        else:
-            assert (
-                self.spec_kwargs["max_tokens"] == num_api_spec_tokens
-            )
+            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
        params = sampling_params.to_openai_kwargs()
        for key, value in params.items():

@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
                )
                prompt = s.messages_
            else:
                return self._prepare_spec_execution(
                    sampling_params, s.num_api_spec_tokens, spec_var_name
                )
        else:
            prompt = s.text_

@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
        ret_str = ret.choices[0].text
        ret_token = self.tokenizer.encode(ret_str)[0]
        self.token_usage.prompt_tokens += ret.usage.prompt_tokens
        self.token_usage.completion_tokens = ret.usage.completion_tokens

        # TODO:
        # 1. return logits as the scores

@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
        return decision, scores, None, None


def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    for attempt in range(retries):
        try:
            if is_chat:

@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
    return comp


def openai_completion_stream(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    for attempt in range(retries):
        try:
            if is_chat:
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
                    messages=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:

@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
                        yield content or "", {}
            else:
                generator = client.completions.create(
                    prompt=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
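Note on the streaming hunks above: with stream_options={"include_usage": True}, the OpenAI Python client emits one final chunk that carries only usage and has an empty choices list, which is why the loops skip chunks where len(ret.choices) == 0. A minimal consumer sketch for a chat-style stream (illustrative, not the sglang code):

def read_stream(generator):
    """Consume a chat-completion stream created with stream_options={"include_usage": True}."""
    pieces, usage = [], None
    for chunk in generator:
        if len(chunk.choices) == 0:  # trailing usage-only chunk
            usage = chunk.usage
            continue
        pieces.append(chunk.choices[0].delta.content or "")
    return "".join(pieces), usage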
python/sglang/lang/interpreter.py

@@ -507,7 +507,7 @@ class StreamExecutor:
                )
                return
            else:  # Speculative execution on models with completion interface
                comp, meta_info = self._spec_gen(sampling_params)
                self.text_ += comp
python/sglang/lang/ir.py

@@ -81,12 +81,10 @@ class SglSamplingParams:
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the LiteLLM backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
python/sglang/launch_server.py

@@ -10,4 +10,4 @@ if __name__ == "__main__":
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
-    launch_server(server_args, None)
\ No newline at end of file
+    launch_server(server_args, None)
python/sglang/launch_server_llavavid.py

"""Launch the inference server for Llava-video model."""

import argparse
import multiprocessing as mp
python/sglang/srt/constrained/__init__.py

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
+from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
python/sglang/srt/constrained/fsm_cache.py

"""Cache for the compressed finite state machine."""

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
python/sglang/srt/constrained/jump_forward.py

@@ -8,11 +8,12 @@ from collections import defaultdict
import interegular
import outlines.caching

from sglang.srt.constrained import (
    FSMInfo,
    disk_cache,
-    make_deterministic_fsm,
+    make_byte_level_fsm,
+    make_deterministic_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache
python/sglang/srt/conversation.py

"""Conversation templates."""

# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
python/sglang/srt/hf_transformers_utils.py

"""Utilities for Huggingface Transformers."""

+import functools
import json
import os
import warnings
-import functools
-from typing import Optional, Union, AbstractSet, Collection, Literal
+from typing import AbstractSet, Collection, Literal, Optional, Union

from huggingface_hub import snapshot_download
from transformers import (

@@ -179,6 +179,7 @@ def get_processor(
class TiktokenTokenizer:
    def __init__(self, tokenizer_path):
        import tiktoken

        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

        # Read JSON

@@ -190,7 +191,8 @@ class TiktokenTokenizer:
            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in tok_dict["special_tokens"]
        }
        assert tok_dict["word_split"] == "V1"

@@ -202,7 +204,10 @@ class TiktokenTokenizer:
        }
        if "default_allowed_special" in tok_dict:
            default_allowed_special = set(
                [
                    bytes(bytes_list).decode()
                    for bytes_list in tok_dict["default_allowed_special"]
                ]
            )
        else:
            default_allowed_special = None

@@ -216,14 +221,20 @@ class TiktokenTokenizer:
            self,
            text: str,
            *,
            allowed_special: Union[
                Literal["all"], AbstractSet[str]
            ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> list[int]:
            if isinstance(allowed_special, set):
                allowed_special |= self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Convert to HF interface

@@ -237,10 +248,14 @@ class TiktokenTokenizer:
    def decode(self, x):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def convert_ids_to_tokens(self, index):
-        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
\ No newline at end of file
+        return self.tokenizer.decode_single_token_bytes(index).decode(
+            "utf-8", errors="ignore"
+        )
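The encode hunk above uses functools.partial to pre-bind the tokenizer instance to a standalone function and then installs the result as the instance's encode attribute. A self-contained sketch of that patching pattern (hypothetical Greeter class, illustrative only):

import functools

class Greeter:
    def __init__(self, name):
        self.name = name

def greet_patched(self, punctuation="!"):
    # acts like a method because `self` is pre-bound below
    return f"hello {self.name}{punctuation}"

g = Greeter("world")
g.greet = functools.partial(greet_patched, g)  # same trick as tokenizer.encode above
print(g.greet())     # hello world!
print(g.greet("?"))  # hello world?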
python/sglang/srt/layers/fused_moe.py

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import is_hip

@@ -109,12 +108,16 @@ def fused_moe_kernel(
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_fp8:
        a_scale = tl.load(a_scale_ptr)

@@ -130,13 +133,12 @@ def fused_moe_kernel(
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)

@@ -147,9 +149,7 @@ def fused_moe_kernel(
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:

@@ -159,15 +159,14 @@ def fused_moe_kernel(
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

@@ -206,32 +205,38 @@ def moe_align_block_size(
    by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad


def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8: bool,
) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

@@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None

    grid = lambda META: (
        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
    )

    fused_moe_kernel[grid](
        A,

@@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

@@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int,
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info(
                "Using configuration from %s for MoE layer.", config_file_path
            )
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

@@ -352,40 +358,30 @@ def fused_moe(
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if is_hip():
        # The MoE kernels are not yet supported on ROCm.
        routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    else:
        import vllm._moe_C as moe_kernels

        topk_weights = torch.empty(
            M, topk, dtype=torch.float32, device=hidden_states.device
        )
        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
        token_expert_indicies = torch.empty(
            M, topk, dtype=torch.int32, device=hidden_states.device
        )
        moe_kernels.topk_softmax(
            topk_weights,
            topk_ids,

@@ -400,8 +396,7 @@ def fused_moe(
        config = override_config
    else:
        # First try to load optimal config from the file
        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)

        if configs:
            # If an optimal configuration map has been found, look up the

@@ -415,7 +410,7 @@ def fused_moe(
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
-            "num_stages": 4
+            "num_stages": 4,
        }
        if M <= E:

@@ -425,61 +420,72 @@ def fused_moe(
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 16,
                "num_warps": 8,
-                "num_stages": 4
+                "num_stages": 4,
            }

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )
    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        intermediate_cache1,
        a1_scale,
        w1_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        False,
        topk_ids.shape[1],
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    invoke_fused_moe_kernel(
        intermediate_cache2,
        w2,
        intermediate_cache3,
        a2_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        True,
        1,
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    if inplace:
        return torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
            out=hidden_states,
        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
\ No newline at end of file
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
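The moe_align_block_size docstring above describes padding each expert's share of tokens up to a multiple of block_size so the expert matmuls can operate on fixed-size blocks; the heavy lifting is done by the vLLM custom op (ops.moe_align_block_size). A pure-Python sketch of just the rounding idea (illustrative only):

def padded_token_counts(tokens_per_expert, block_size):
    # ceil each expert's token count to the next multiple of block_size
    return [-(-count // block_size) * block_size for count in tokens_per_expert]

print(padded_token_counts([5, 1, 0, 9], block_size=4))  # [8, 4, 0, 12]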
python/sglang/srt/layers/logits_processor.py

"""Logits processing."""

import torch
from torch import nn
from vllm.distributed import (
python/sglang/srt/layers/radix_attention.py

"""Radix attention."""

-import torch
+import numpy as np
+import torch
from torch import nn

from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd

@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
class RadixAttention(nn.Module):
    def __init__(
        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
    ):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
python/sglang/srt/managers/controller/dp_worker.py

@@ -4,7 +4,7 @@ import asyncio
import logging
import queue
import threading
-from typing import List, Callable
+from typing import Callable, List

import uvloop
import zmq

@@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
            # async sleep for receiving the subsequent request and avoiding cache miss
            if len(out_pyobjs) != 0:
                has_finished = any(
                    [obj.finished_reason is not None for obj in out_pyobjs]
                )
                if has_finished:
                    await asyncio.sleep(self.request_dependency_delay)
            await asyncio.sleep(global_config.wait_for_new_request_delay)

@@ -108,4 +110,4 @@ def start_data_parallel_worker(
        step_func=model_tp_client.step,
    )
    worker_thread.start()
-    return worker_thread
\ No newline at end of file
+    return worker_thread
python/sglang/srt/managers/controller/infer_batch.py

"""Meta data for requests and batches"""

+import warnings
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
-import warnings

import numpy as np
import torch

+from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.constrained import RegexGuide

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5