sglang · Commit fb9296f0 (unverified)
Authored Jun 12, 2024 by Ying Sheng; committed by GitHub, Jun 12, 2024
Parent: 1374334d

Higher priority for user input of max_prefill_tokens & format (#540)
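The functional change named in the title gives a user-supplied max_prefill_tokens precedence over the value the server would otherwise derive; the hunks shown below are dominated by formatting (black/isort style) cleanups. The server-args hunk itself is not among the diffs on this page, so the following is only a minimal sketch of the precedence idea, with a hypothetical helper name and made-up default values:

    # Hypothetical illustration of "user input wins" precedence; not code from this commit.
    from typing import Optional

    def resolve_max_prefill_tokens(user_value: Optional[int], derived_default: int) -> int:
        # Honor an explicit --max-prefill-tokens if the user passed one,
        # otherwise fall back to the value derived from the model context length.
        return user_value if user_value is not None else derived_default

    print(resolve_max_prefill_tokens(None, 16384))   # 16384 (derived default)
    print(resolve_max_prefill_tokens(32768, 16384))  # 32768 (user input takes priority)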
Showing 20 changed files with 227 additions and 172 deletions (+227, -172)
benchmark/gsm8k/bench_other.py                        +1    -1
benchmark/latency_throughput/bench_throughput.py      +24   -10
benchmark/mmlu/bench_other.py                         +1    -1
python/sglang/__init__.py                             +1    -1
python/sglang/backend/litellm.py                      +2    -1
python/sglang/backend/openai.py                       +26   -15
python/sglang/lang/interpreter.py                     +1    -1
python/sglang/lang/ir.py                              +2    -4
python/sglang/launch_server.py                        +1    -1
python/sglang/launch_server_llavavid.py               +1    -0
python/sglang/srt/constrained/__init__.py             +1    -1
python/sglang/srt/constrained/fsm_cache.py            +1    -0
python/sglang/srt/constrained/jump_forward.py         +2    -1
python/sglang/srt/conversation.py                     +1    -0
python/sglang/srt/hf_transformers_utils.py            +23   -8
python/sglang/srt/layers/fused_moe.py                 +125  -119
python/sglang/srt/layers/logits_processor.py          +1    -0
python/sglang/srt/layers/radix_attention.py           +5    -2
python/sglang/srt/managers/controller/dp_worker.py    +5    -3
python/sglang/srt/managers/controller/infer_batch.py  +3    -3
benchmark/gsm8k/bench_other.py

@@ -65,7 +65,7 @@ def main(args):
    def get_one_answer(i):
        answer = call_generate(
            prompt=few_shot_examples + questions[i],
-           #prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
+           # prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
            temperature=0,
            max_tokens=256,
            stop="Question",
benchmark/latency_throughput/bench_throughput.py

@@ -158,7 +158,9 @@ async def send_request(
    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
-           async with session.post(api_url, headers=headers, json=pload) as response:
+           async with session.post(
+               api_url, headers=headers, json=pload
+           ) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
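The hunk above only re-wraps the session.post(...) call; the surrounding pattern, one ClientSession with a long total timeout and chunk-wise reading of the streaming response, is unchanged. A self-contained sketch of that aiohttp pattern, with a placeholder URL and payload rather than the benchmark's real request:

    # Minimal sketch of the aiohttp streaming pattern used in send_request.
    # The endpoint and payload below are placeholders, not the benchmark's actual request.
    import asyncio
    import aiohttp

    async def post_and_collect(api_url: str, pload: dict) -> bytes:
        timeout = aiohttp.ClientTimeout(total=3 * 3600)  # generous cap for long generations
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(api_url, json=pload) as response:
                chunks = []
                # iter_chunks() yields (bytes, end_of_http_chunk) pairs as data arrives.
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
                return b"".join(chunks)

    if __name__ == "__main__":
        data = asyncio.run(post_and_collect("http://localhost:30000/generate", {"text": "Hello"}))
        print(len(data), "bytes received")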
@@ -228,19 +230,32 @@ def main(args: argparse.Namespace):
    np.random.seed(args.seed)
    api_url = f"http://{args.host}:{args.port}/generate"
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    )
    if args.dataset:
        input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
    else:
        input_lens = np.random.randint(
            int(args.input_len * args.range_ratio),
            args.input_len + 1,
            size=args.num_prompts,
        )
        output_lens = np.random.randint(
            int(args.output_len * args.range_ratio),
            args.output_len + 1,
            size=args.num_prompts,
        )
        offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
        input_requests = []
        for i in range(args.num_prompts):
            prompt = tokenizer.decode(
                [
                    (offsets[i] + i + j) % tokenizer.vocab_size
                    for j in range(input_lens[i])
                ]
            )
            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

    benchmark_start_time = time.perf_counter()
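When no --dataset is given, the benchmark synthesizes prompts by decoding pseudo-random token IDs: for request i, token j is (offsets[i] + i + j) % vocab_size, with lengths drawn between input_len * range_ratio and input_len. A tiny stand-alone sketch of just the ID generation (example sizes; the final tokenizer.decode step is omitted so it runs without a model):

    # Sketch of the synthetic-prompt ID scheme from the hunk above; sizes are examples.
    import numpy as np

    np.random.seed(0)
    vocab_size = 32000
    num_prompts = 3
    input_lens = np.random.randint(8, 17, size=num_prompts)       # example prompt lengths
    offsets = np.random.randint(0, vocab_size, size=num_prompts)  # per-request starting offset

    for i in range(num_prompts):
        token_ids = [(offsets[i] + i + j) % vocab_size for j in range(input_lens[i])]
        # The benchmark passes these IDs to tokenizer.decode(...) to obtain the prompt text.
        print(i, len(token_ids), token_ids[:5])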
@@ -287,16 +302,15 @@ if __name__ == "__main__":
    )
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument("--dataset", type=str, help="Path to the dataset.")
    parser.add_argument("--input-len", type=int, default=2048)
    parser.add_argument("--output-len", type=int, default=256)
    parser.add_argument("--range-ratio", type=float, default=1.0)
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="NousResearch/Meta-Llama-3-8B",
        help="Name or path of the tokenizer.",
    )
    parser.add_argument(
        "--best-of",
benchmark/mmlu/bench_other.py

@@ -170,4 +170,4 @@ if __name__ == "__main__":
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--nsub", type=int, default=60)
    args = add_common_other_args_and_parse(parser)
-   main(args)
\ No newline at end of file
+   main(args)
python/sglang/__init__.py

@@ -24,10 +24,10 @@ from sglang.api import (
# SGL Backends
from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.backend.vertexai import VertexAI
-from sglang.backend.litellm import LiteLLM

# Global Configurations
from sglang.global_config import global_config
python/sglang/backend/litellm.py

@@ -33,7 +33,8 @@ class LiteLLM(BaseBackend):
        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        self.client_params = {
            "api_key": api_key,
python/sglang/backend/openai.py

+import dataclasses
import logging
import time
import warnings
-import dataclasses
from typing import Callable, List, Optional, Union

import numpy as np
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
    def get_chat_template(self):
        return self.chat_template

    def _prepare_spec_execution(
        self,
        sampling_params: SglSamplingParams,
        num_api_spec_tokens: int,
        spec_var_name: str,
    ):
        if "max_tokens" not in self.spec_kwargs:
            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
        else:
            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
        params = sampling_params.to_openai_kwargs()
        for key, value in params.items():
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
                )
                prompt = s.messages_
            else:
                return self._prepare_spec_execution(
                    sampling_params, s.num_api_spec_tokens, spec_var_name
                )
        else:
            prompt = s.text_
@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
        ret_str = ret.choices[0].text
        ret_token = self.tokenizer.encode(ret_str)[0]
        self.token_usage.prompt_tokens += ret.usage.prompt_tokens
        self.token_usage.completion_tokens = ret.usage.completion_tokens

        # TODO:
        # 1. return logits as the scores
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
        return decision, scores, None, None


def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    for attempt in range(retries):
        try:
            if is_chat:
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
    return comp


def openai_completion_stream(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    for attempt in range(retries):
        try:
            if is_chat:
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
                    messages=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
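Both streaming helpers pass stream_options={"include_usage": True}, which makes the OpenAI client emit one final chunk that carries usage statistics and an empty choices list; that is why the loops guard on len(ret.choices) == 0. A hedged sketch of the same pattern with the openai Python client (the model name and message are placeholders, and a recent client version with stream_options support is assumed):

    # Sketch of streaming with usage reporting, mirroring openai_completion_stream above.
    # The model and message are example values; an OPENAI_API_KEY is assumed to be set.
    from openai import OpenAI

    client = OpenAI()
    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in stream:
        if len(chunk.choices) == 0:
            # Final usage-only chunk: no delta content, just token counts.
            print("\nusage:", chunk.usage)
            continue
        print(chunk.choices[0].delta.content or "", end="")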
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
                    yield content or "", {}
            else:
                generator = client.completions.create(
                    prompt=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
python/sglang/lang/interpreter.py

@@ -507,7 +507,7 @@ class StreamExecutor:
                )
                return
            else:  # Speculative execution on models with completion interface
                comp, meta_info = self._spec_gen(sampling_params)
                self.text_ += comp
python/sglang/lang/ir.py

@@ -81,12 +81,10 @@ class SglSamplingParams:
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the LiteLLM backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
python/sglang/launch_server.py

@@ -10,4 +10,4 @@ if __name__ == "__main__":
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
-   launch_server(server_args, None)
\ No newline at end of file
+   launch_server(server_args, None)
python/sglang/launch_server_llavavid.py

"""Launch the inference server for Llava-video model."""

import argparse
import multiprocessing as mp
python/sglang/srt/constrained/__init__.py

@@ -4,7 +4,7 @@ from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
+from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
python/sglang/srt/constrained/fsm_cache.py

"""Cache for the compressed finite state machine."""

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
python/sglang/srt/constrained/jump_forward.py

@@ -8,11 +8,12 @@ from collections import defaultdict
import interegular
import outlines.caching

from sglang.srt.constrained import (
    FSMInfo,
    disk_cache,
-   make_deterministic_fsm,
+   make_byte_level_fsm,
+   make_deterministic_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache
python/sglang/srt/conversation.py

"""Conversation templates."""

# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
python/sglang/srt/hf_transformers_utils.py

"""Utilities for Huggingface Transformers."""

+import functools
import json
import os
import warnings
-import functools
-from typing import Optional, Union, AbstractSet, Collection, Literal
+from typing import AbstractSet, Collection, Literal, Optional, Union

from huggingface_hub import snapshot_download
from transformers import (
@@ -179,6 +179,7 @@ def get_processor(
class TiktokenTokenizer:
    def __init__(self, tokenizer_path):
        import tiktoken

        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

        # Read JSON
@@ -190,7 +191,8 @@ class TiktokenTokenizer:
            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in tok_dict["special_tokens"]
        }
        assert tok_dict["word_split"] == "V1"
@@ -202,7 +204,10 @@ class TiktokenTokenizer:
        }
        if "default_allowed_special" in tok_dict:
            default_allowed_special = set(
                [
                    bytes(bytes_list).decode()
                    for bytes_list in tok_dict["default_allowed_special"]
                ]
            )
        else:
            default_allowed_special = None
@@ -216,14 +221,20 @@ class TiktokenTokenizer:
            self,
            text: str,
            *,
            allowed_special: Union[
                Literal["all"], AbstractSet[str]
            ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> list[int]:
            if isinstance(allowed_special, set):
                allowed_special |= self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Convert to HF interface
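The last line of the hunk, tokenizer.encode = functools.partial(encode_patched, tokenizer), installs a wrapper that behaves like a bound method by pre-binding the instance as the first argument. A small stand-alone sketch of that patching trick with generic names (not from sglang):

    # Monkey-patching an instance attribute with functools.partial, as done for tokenizer.encode.
    # The class and function names here are illustrative only.
    import functools

    class Greeter:
        def greet(self, name):
            return f"hello {name}"

    def greet_patched(self, name):
        # Wrapper that still receives the instance explicitly as `self`.
        return self.greet(name).upper()

    g = Greeter()
    g.greet = functools.partial(greet_patched, g)  # pre-bind the instance
    print(g.greet("world"))  # HELLO WORLD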
@@ -237,10 +248,14 @@ class TiktokenTokenizer:
    def decode(self, x):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def convert_ids_to_tokens(self, index):
-       return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
\ No newline at end of file
+       return self.tokenizer.decode_single_token_bytes(index).decode(
+           "utf-8", errors="ignore"
+       )
python/sglang/srt/layers/fused_moe.py

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import is_hip
@@ -109,12 +108,16 @@ def fused_moe_kernel(
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_fp8:
        a_scale = tl.load(a_scale_ptr)
@@ -130,13 +133,12 @@ def fused_moe_kernel(
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)
@@ -147,9 +149,7 @@ def fused_moe_kernel(
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:
@@ -159,15 +159,14 @@ def fused_moe_kernel(
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.
@@ -206,32 +205,38 @@ def moe_align_block_size(
    by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad


def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8: bool,
) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1
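The padding bound in moe_align_block_size covers the worst case where every expert's token count falls just short of a block boundary: max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1). A quick numeric check of that bound with made-up sizes:

    # Worked example of the padding bound used above; all sizes are made up.
    num_tokens, top_k = 6, 2           # topk_ids holds num_tokens * top_k routed pairs
    num_experts, block_size = 4, 16

    numel = num_tokens * top_k                                       # 12
    max_num_tokens_padded = numel + num_experts * (block_size - 1)   # 12 + 4 * 15 = 72
    max_num_m_blocks = -(-max_num_tokens_padded // block_size)       # ceil(72 / 16) = 5

    print(max_num_tokens_padded, max_num_m_blocks)  # 72 5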
@@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None

    grid = lambda META: (
        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
    )

    fused_moe_kernel[grid](
        A,
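The grid lambda launches a one-dimensional grid sized as the number of M-blocks over the sorted, padded token IDs times the number of N-blocks over B's output dimension. A short arithmetic sketch with example sizes:

    # Example launch-grid arithmetic for the lambda above; the sizes are illustrative.
    def cdiv(a, b):
        return -(-a // b)  # ceiling division, like triton.cdiv

    sorted_token_ids_len = 72      # e.g. the padded count from moe_align_block_size
    N = 4096                       # B.shape[1], the expert weight's output dimension
    BLOCK_SIZE_M, BLOCK_SIZE_N = 16, 64

    grid_size = cdiv(sorted_token_ids_len, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)
    print(grid_size)  # 5 * 64 = 320 program instances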
@@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.
@@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int,
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info("Using configuration from %s for MoE layer.", config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}
@@ -352,40 +358,30 @@ def fused_moe(
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    M, _ = hidden_states.shape
    E, N, _ = w1.shape

    if is_hip():
        # The MoE kernels are not yet supported on ROCm.
        routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
    else:
        import vllm._moe_C as moe_kernels

        topk_weights = torch.empty(
            M, topk, dtype=torch.float32, device=hidden_states.device
        )
        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
        token_expert_indicies = torch.empty(
            M, topk, dtype=torch.int32, device=hidden_states.device
        )
        moe_kernels.topk_softmax(
            topk_weights,
            topk_ids,
@@ -400,8 +396,7 @@ def fused_moe(
        config = override_config
    else:
        # First try to load optimal config from the file
        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)

        if configs:
            # If an optimal configuration map has been found, look up the
@@ -415,7 +410,7 @@ def fused_moe(
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
            "num_stages": 4,
        }
        if M <= E:
@@ -425,61 +420,72 @@ def fused_moe(
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 16,
            "num_warps": 8,
            "num_stages": 4,
        }

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )
    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        intermediate_cache1,
        a1_scale,
        w1_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        False,
        topk_ids.shape[1],
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    invoke_fused_moe_kernel(
        intermediate_cache2,
        w2,
        intermediate_cache3,
        a2_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        True,
        1,
        config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    if inplace:
        return torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
            out=hidden_states,
        )
-   return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
\ No newline at end of file
+   return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
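Beyond re-wrapping calls, the hunk above lays out the two-stage MoE computation: intermediate_cache1 holds the w1 outputs per (token, selected expert), ops.gelu_and_mul reduces them to the gated half in intermediate_cache2, and intermediate_cache3 holds the w2 outputs that are finally summed over the top-k experts. A shape-only sketch with example sizes (no GPU required):

    # Shape bookkeeping for the three intermediate caches; all sizes are examples.
    M, top_k = 8, 2        # tokens, experts selected per token (topk_ids.shape)
    N = 4096               # w1 output width; gelu_and_mul halves it
    w2_out = 2048          # w2.shape[1]

    cache1 = (M, top_k, N)          # per-token, per-expert w1 output
    cache2 = (M * top_k, N // 2)    # after ops.gelu_and_mul
    cache3 = (M, top_k, w2_out)     # per-token, per-expert w2 output, summed over dim=1

    print(cache1, cache2, cache3)   # (8, 2, 4096) (16, 2048) (8, 2, 2048)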
python/sglang/srt/layers/logits_processor.py

"""Logits processing."""

import torch
from torch import nn
from vllm.distributed import (
python/sglang/srt/layers/radix_attention.py

"""Radix attention."""

import numpy as np
import torch
from torch import nn

from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd

@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
class RadixAttention(nn.Module):
    def __init__(
        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
    ):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
python/sglang/srt/managers/controller/dp_worker.py

@@ -4,7 +4,7 @@ import asyncio
import logging
import queue
import threading
-from typing import List, Callable
+from typing import Callable, List

import uvloop
import zmq

@@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
            # async sleep for receiving the subsequent request and avoiding cache miss
            if len(out_pyobjs) != 0:
                has_finished = any(
                    [obj.finished_reason is not None for obj in out_pyobjs]
                )
                if has_finished:
                    await asyncio.sleep(self.request_dependency_delay)
            await asyncio.sleep(global_config.wait_for_new_request_delay)

@@ -108,4 +110,4 @@ def start_data_parallel_worker(
        step_func=model_tp_client.step,
    )
    worker_thread.start()
-   return worker_thread
\ No newline at end of file
+   return worker_thread
python/sglang/srt/managers/controller/infer_batch.py

"""Meta data for requests and batches"""

+import warnings
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
-import warnings

import numpy as np
import torch

+from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.constrained import RegexGuide

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5