sglang · Commit bf51ddc6 (unverified)
Authored Jan 17, 2024 by Lianmin Zheng; committed via GitHub on Jan 17, 2024
Parent: fd7c4792

Improve docs & Rename Gemini -> VertexAI (#19)

Showing 13 changed files with 56 additions and 583 deletions (+56 −583)
README.md                                                 +4   −1
examples/quick_start/gemini_example_complete.py           +2   −2
examples/quick_start/gemini_example_multimodal_chat.py    +3   −3
examples/quick_start/gemini_example_stream.py             +2   −2
python/sglang/api.py                                      +1   −1
python/sglang/backend/huggingface.py                      +0   −349
python/sglang/backend/tgi.py                              +0   −190
python/sglang/backend/vertexai.py                         +20  −25
python/sglang/lang/ir.py                                  +8   −1
python/sglang/srt/managers/router/manager.py              +2   −2
python/sglang/srt/managers/router/model_rpc.py            +1   −1
python/sglang/srt/server_args.py                          +7   −0
test/lang/test_vertexai_backend.py                        +6   −6
README.md

# SGLang
| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |

SGLang is a structured generation language designed for large language models (LLMs).
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
...
@@ -42,7 +43,7 @@ The example below shows how to use sglang to answer a multi-turn question.
### Using OpenAI Models
Set the OpenAI API Key
```
-export OPENAI_API_KEY=sk-xxxxxx
+export OPENAI_API_KEY=sk-******
```
Then, answer a multi-turn question.
...
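The program itself is elided in this excerpt. As a hedged sketch (not the README's exact listing), a multi-turn program on the OpenAI backend could look like the following; it assumes `OpenAI` is importable from the top-level `sglang` package the same way `VertexAI` is in the test file below, and the model name and second question are placeholders:

```python
from sglang import OpenAI, assistant, function, gen, set_default_backend, user


@function
def multi_turn_question(s, question_1, question_2):
    # Each user() call appends a user turn; gen() asks the backend for the
    # assistant reply and stores it in the state under the given name.
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))


set_default_backend(OpenAI("gpt-3.5-turbo"))  # model name is a placeholder

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions there.",  # hypothetical follow-up
)
print(state["answer_1"])
print(state["answer_2"])
```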
@@ -100,6 +101,7 @@ for m in state.messages():
### More Examples
Anthropic and VertexAI (Gemini) models are also supported.
You can find more examples at [examples/quick_start](examples/quick_start).
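To make the backend-swapping claim concrete: only the `set_default_backend` call changes between providers. The sketch below is illustrative rather than taken from the repository; it assumes `Anthropic` is re-exported from the top-level package like the other backends imported in `python/sglang/api.py`, and the Claude model name is an assumption.

```python
from sglang import Anthropic, VertexAI, set_default_backend

# The same @function programs run unchanged against either provider.
set_default_backend(Anthropic("claude-2"))  # model name is an assumption

# ... or against VertexAI (Gemini), as in examples/quick_start:
set_default_backend(VertexAI("gemini-pro"))
```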
## Frontend: Structured Generation Language (SGLang)
...
@@ -251,6 +253,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
- Mixtral
- LLaVA
  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
- AWQ quantization
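Once a server is launched with a command like the LLaVA one above, the SGLang frontend attaches to it through the `RuntimeEndpoint` backend imported in `python/sglang/api.py`. A minimal sketch, assuming the class is re-exported at top level and the server listens on the `--port 30000` used above:

```python
from sglang import RuntimeEndpoint, set_default_backend

# Point the frontend at the locally launched SRT server.
set_default_backend(RuntimeEndpoint("http://localhost:30000"))
```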
## Benchmark And Performance
...
examples/quick_start/gemini_example_complete.py

-from sglang import function, gen, set_default_backend, Gemini
+from sglang import function, gen, set_default_backend, VertexAI

@function
...
@@ -16,7 +16,7 @@ A: Rome
    s += "A:" + gen("answer", stop="\n", temperature=0)

-set_default_backend(Gemini("gemini-pro"))
+set_default_backend(VertexAI("gemini-pro"))

state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()
...
examples/quick_start/gemini_example_multimodal_chat.py

-from sglang import function, user, assistant, gen, image, set_default_backend, Gemini
+from sglang import function, user, assistant, gen, image, set_default_backend, VertexAI

@function
...
@@ -6,7 +6,7 @@ def image_qa(s, image_file1, image_file2, question):
    s += user(image(image_file1) + image(image_file2) + question)
    s += assistant(gen("answer_1", max_tokens=256))

-set_default_backend(Gemini("gemini-pro-vision"))
+set_default_backend(VertexAI("gemini-pro-vision"))

state = image_qa.run(
    image_file1="./images/cat.jpeg",
...
@@ -16,4 +16,4 @@ state = image_qa.run(
)

for out in state.text_iter():
-    print(out, end="", flush=True)
\ No newline at end of file
+    print(out, end="", flush=True)
examples/quick_start/gemini_example_stream.py

-from sglang import function, user, assistant, gen, set_default_backend, Gemini
+from sglang import function, user, assistant, gen, set_default_backend, VertexAI

@function
...
@@ -8,7 +8,7 @@ def multi_turn_question(s, question_1, question_2):
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

-set_default_backend(Gemini("gemini-pro"))
+set_default_backend(VertexAI("gemini-pro"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
...
python/sglang/api.py
...
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
-from sglang.backend.gemini import Gemini
+from sglang.backend.vertexai import VertexAI
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
...
python/sglang/backend/huggingface.py  (deleted, 100644 → 0)

import functools
from enum import Enum, auto
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import transformers
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import ProgramState
from sglang.utils import get_available_gpu_memory
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)


class StopReason(Enum):
    EOS_TOKEN = auto()
    STOP_STR = auto()
    LENGTH = auto()


def load_model(
    model_name: str,
    device,
    num_gpus,
    max_gpu_memory,
    model_kwargs=None,
    tokenizer_kwargs=None,
):
    model_kwargs = model_kwargs or {}
    tokenizer_kwargs = tokenizer_kwargs or {}

    if device == "cuda":
        model_kwargs["torch_dtype"] = torch.float16
        if num_gpus != 1:
            model_kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                # "sequential" is important when the GPUs have different VRAM sizes.
                model_kwargs["device_map"] = "sequential"
                available_gpu_memory = [
                    get_available_gpu_memory(i, False) for i in range(num_gpus)
                ]
                model_kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                model_kwargs["max_memory"] = {
                    i: max_gpu_memory for i in range(num_gpus)
                }
    elif device == "cpu":
        model_kwargs["torch_dtype"] = torch.float32
    else:
        raise ValueError(f"Invalid device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
    if num_gpus == 1:
        model.to(device).eval()

    return model, tokenizer


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0; 1.0 makes it a no-op, so we skip those two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list


@functools.lru_cache
def get_token_healing_mask(tokenizer, prompt_last_token):
    last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        if not s.startswith(last_str):
            disallowed[t_id] = 1
    return disallowed


@functools.lru_cache
def get_int_token_mask(tokenizer):
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        s = s.replace("▁", "").strip()
        if not (s.isdigit() or len(s) == 0 or s == ","):
            disallowed[t_id] = 1
    disallowed[tokenizer.eos_token_id] = 0
    return disallowed


@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens,
    stop: List[str],
    temperature,
    top_p,
    token_healing,
    logit_mask=None,
):
    logits_processor = prepare_logits_processor(
        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
    )
    device = model.device

    input_ids = tokenizer.encode(prompt)
    output_ids = list(input_ids)
    prompt_len = len(prompt)

    # Resolve stop
    stop_token_ids = [tokenizer.eos_token_id]

    # Token healing
    token_healing = token_healing and len(input_ids) > 0
    if token_healing:
        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
        del output_ids[-1]

    # Generate
    past_key_values = None
    stop_reason = None
    for i in range(max_new_tokens):
        # Forward
        if i == 0:  # prefill
            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
        else:  # decoding
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
        logits = out.logits
        past_key_values = out.past_key_values

        # Logit mask
        if token_healing and i == 0:
            logits[0, -1, token_healing_mask] = -1e4
        if logit_mask is not None:
            logits[0, -1, logit_mask] = -1e4

        # Sample next token
        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        # Stop condition
        if token in stop_token_ids:
            stop_reason = StopReason.EOS_TOKEN
            break

        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop:
            pos = output_str[prompt_len:].find(stop_str)
            if pos != -1:
                stop_reason = StopReason.STOP_STR
                output_str = output_str[: prompt_len + pos]
                break

        if stop_reason:
            break

    return output_str[prompt_len:]


class HuggingFaceTransformers(BaseBackend):
    def __init__(
        self,
        model_name,
        device="cuda",
        num_gpus=1,
        max_gpu_memory=None,
        model_kwargs=None,
        tokenizer_kwargs=None,
    ):
        self.model_name = model_name
        self.device = device
        self.model, self.tokenizer = load_model(
            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
        )
        self.chat_template = get_chat_template_by_model_path(model_name)

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        pass

    def uncache_prefix(self, rid: str):
        pass

    def end_request(self, rid: str):
        pass

    def begin_program(self, s: ProgramState):
        pass

    def end_program(self, s: ProgramState):
        pass

    def fill(self, s: ProgramState, text: str):
        return False

    def generate_internal(
        self,
        prompt: str,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                token_healing=True,
            )
        elif dtype in [str, "str", "string"]:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt + '"',
                max_new_tokens=max_tokens,
                stop=['"'],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
            )
            comp = '"' + comp + '"'
        elif dtype in [int, "int"]:
            logit_mask = get_int_token_mask(self.tokenizer)
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop + [" ", ","],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
                logit_mask=logit_mask,
            )
        return comp

    def generate(
        self,
        s: ProgramState,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        prompt = s.text
        comp = self.generate_internal(prompt, max_tokens, stop, temperature, top_p, dtype)
        return comp

    def parallel_generate(
        self,
        s: ProgramState,
        prefixes: List[str],
        join_func: Callable,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        prompt = s.text
        parallel_prompts = [prompt + prefix for prefix in prefixes]

        comps = []
        for i in range(len(parallel_prompts)):
            comps.append(
                self.generate_internal(
                    parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
                )
            )

        joined = join_func([p + c for p, c in zip(prefixes, comps)])
        return joined, comps

    @torch.inference_mode()
    def select(self, s: ProgramState, choices: List[str], temperature: float, top_p: float):
        loss_fct = torch.nn.CrossEntropyLoss()

        prompt = s.text
        prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
        prompt_choices = [prompt + choice for choice in choices]

        scores = []
        for i in range(len(choices)):
            choice_ids = self.tokenizer.encode(
                prompt_choices[i], return_tensors="pt"
            ).to(self.model.device)
            logits = self.model(choice_ids).logits
            # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
            logprobs = torch.log(torch.softmax(logits, dim=-1))
            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
            idx2 = choice_ids[0, 1:]
            selected_logprobs = logprobs[0, idx1, idx2]
            score = selected_logprobs.mean().item()
            scores.append(score)

        decision = choices[np.argmax(scores)]
        return decision, scores
python/sglang/backend/tgi.py  (deleted, 100644 → 0)

import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from itertools import repeat
from typing import List, Optional, Union

from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request


class TGI(BaseBackend):
    def __init__(self, base_url):
        super().__init__()

        self.base_url = base_url

        res = http_request(self.base_url + "/info")
        assert res.status_code == 200
        self.model_info = res.json()
        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_id"]
        )

    def get_model_name(self):
        return self.model_info["model_id"]

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def adapt_params(max_tokens, stop, sampling_params, **override_params):
        temperature = sampling_params.temperature
        do_sample = True
        if temperature == 0:
            do_sample = False
            temperature = None

        if stop is None:
            stop = []
        elif isinstance(stop, str):
            stop = [stop]

        top_p = sampling_params.top_p
        if top_p == 0:
            top_p = 0.001
        if top_p == 1:
            top_p = 0.999

        top_k = sampling_params.top_k
        if top_k == -1:
            top_k = None

        params = {
            "decoder_input_details": False,
            "details": False,
            "do_sample": do_sample,
            "max_new_tokens": max_tokens,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "return_full_text": False,
        }
        params.update(override_params)
        return params

    @staticmethod
    def _extract_int(text):
        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
        for word in words:
            try:
                int(word)
                return word
            except ValueError:
                continue
        raise ValueError

    @staticmethod
    def _extract_choice(choices, text):
        # FIXME: Currently only supports the case where the choices are single words.
        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
        for word in words:
            if word in choices:
                return word
        raise ValueError

    @staticmethod
    def _truncate_to_stop(text, stop):
        # The stop sequence may not be a single token. In this case TGI will generate
        # too many tokens, so we need to truncate the output.
        if stop:
            stop = [stop] if isinstance(stop, str) else stop
            for stop_seq in stop:
                pos = text.find(stop_seq)
                if pos != -1:
                    return text[:pos]
        return text

    def _make_request(self, params):
        res = http_request(self.base_url + "/generate", json=params)
        if res.status_code != 200:
            raise ValueError(f"Error from TGI backend: {res.text}")
        return res.json()

    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
        # TGI does not support logit_bias (yet), so we have to use an inefficient hack.
        failed = []
        while retry > 0:
            res_json = self._make_request(
                {
                    "inputs": prompt,
                    "parameters": params,
                }
            )
            text = res_json["generated_text"]
            try:
                return extract_fn(text)
            except ValueError:
                retry -= 1
                failed.append(text)

        msg = "=" * 20 + "\n"
        msg += f"Prompt:\n{prompt}\n"
        msg += "=" * 20 + "\n"
        for i, text in enumerate(failed):
            msg += f"====== Try {i + 1}:\n{text}\n"
        raise ValueError(
            f"Model {self.model_info['model_id']} served by TGI backend does not generate "
            "expected output. Please improve the prompt, increase the temperature, or "
            f"use different models.\n{msg}"
        )

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        sampling_params: SglSamplingParams,
    ):
        decision = self.retry_for_expected(
            s.text_,
            self.adapt_params(16, [], sampling_params),
            partial(self._extract_choice, choices),
        )
        return decision, [1 if choice == decision else 0 for choice in choices]

    def generate(
        self,
        s: StreamExecutor,
        max_tokens: int,
        stop: Union[str, List[str]],
        sampling_params: SglSamplingParams,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
            res_json = self._make_request(
                {
                    "inputs": s.text_,
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return self._truncate_to_stop(res_json["generated_text"], stop), {}

        if dtype in [str, "str", "string"]:
            stop = ['"']
            res_json = self._make_request(
                {
                    "inputs": f'{s.text_}"',
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return (
                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
                {},
            )

        if dtype in [int, "int"]:
            return (
                self.retry_for_expected(
                    s.text_,
                    self.adapt_params(max_tokens, stop, sampling_params),
                    self._extract_int,
                ),
                {},
            )

        raise ValueError(f"Unknown dtype: {dtype}")
python/sglang/backend/gemini.py → python/sglang/backend/vertexai.py
...
@@ -18,13 +18,8 @@ try:
except ImportError as e:
    GenerativeModel = e

-GEMINI_MODEL_NAMES = [
-    "gemini-pro",
-    "gemini-pro-vision",
-]

-class Gemini(BaseBackend):
+class VertexAI(BaseBackend):
    def __init__(self, model_name):
        super().__init__()
...
@@ -32,7 +27,7 @@ class Gemini(BaseBackend):
            raise GenerativeModel
        project_id = os.environ["GCP_PROJECT_ID"]
-        location = os.environ["GCP_LOCATION"]
+        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)
        self.model_name = model_name
...
@@ -47,17 +42,17 @@ class Gemini(BaseBackend):
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
-            prompt = self.messages_to_gemini_input(s.messages_)
+            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
-                self.text_to_gemini_input(s.text_, s.cur_images)
+                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
-            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
        )

        comp = ret.text
...
@@ -70,23 +65,23 @@ class Gemini(BaseBackend):
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
-            prompt = self.messages_to_gemini_input(s.messages_)
+            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
-                self.text_to_gemini_input(s.text_, s.cur_images)
+                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
-            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
        )
        for ret in generator:
            yield ret.text, {}

-    def text_to_gemini_input(self, text, images):
+    def text_to_vertexai_input(self, text, images):
        input = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
...
@@ -100,9 +95,9 @@ class Gemini(BaseBackend):
                input.append(text_seg)
        return input

-    def messages_to_gemini_input(self, messages):
-        gemini_message = []
-        # from openai message format to gemini message format
+    def messages_to_vertexai_input(self, messages):
+        vertexai_message = []
+        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
...
@@ -110,14 +105,14 @@ class Gemini(BaseBackend):
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
-                warnings.warn("Warning: system prompt is not supported in Gemini.")
-                gemini_message.append(
+                warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
-                gemini_message.append(
+                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
...
@@ -125,12 +120,12 @@ class Gemini(BaseBackend):
                )
                continue

            if msg["role"] == "user":
-                gemini_msg = {
+                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
-                gemini_msg = {
+                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }
...
@@ -139,7 +134,7 @@ class Gemini(BaseBackend):
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
-                    gemini_msg["parts"].append(
+                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
...
@@ -148,5 +143,5 @@ class Gemini(BaseBackend):
                    }
                )

-            gemini_message.append(gemini_msg)
-        return gemini_message
+            vertexai_message.append(vertexai_msg)
+        return vertexai_message
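For orientation, a minimal setup sketch for the renamed backend, based only on what `__init__` above reads: the project ID and region values are placeholders, and `GCP_LOCATION` becomes optional after this change since it is now read with `.get()`.

```python
import os

# Placeholder values; __init__ reads these before calling vertexai.init().
os.environ["GCP_PROJECT_ID"] = "my-gcp-project"
os.environ["GCP_LOCATION"] = "us-central1"

from sglang import VertexAI, set_default_backend

set_default_backend(VertexAI("gemini-pro"))
```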
python/sglang/lang/ir.py
...
@@ -2,6 +2,7 @@
import dataclasses
import inspect
+import warnings
from typing import List, Optional, Union

from sglang.global_config import global_config
...
@@ -40,6 +41,8 @@ class SglSamplingParams:
    def to_openai_kwargs(self):
        # OpenAI does not support top_k, so we drop it here
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
...
@@ -49,7 +52,9 @@ class SglSamplingParams:
            "presence_penalty": self.presence_penalty,
        }

-    def to_gemini_kwargs(self):
+    def to_vertexai_kwargs(self):
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the VertexAI backend.")
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
...
@@ -61,6 +66,8 @@ class SglSamplingParams:
    def to_anthropic_kwargs(self):
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the Anthropic backend.")
        return {
            "max_tokens_to_sample": self.max_new_tokens,
            "stop_sequences": self.stop,
...
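To make the renamed mapper concrete, here is a hedged usage sketch. The constructor fields used below (`max_new_tokens`, `temperature`, `top_p`) are assumed to be dataclass fields of `SglSamplingParams` with these names; only `max_new_tokens`, `stop`, `presence_penalty`, and `regex` are actually visible in this diff.

```python
from sglang.lang.ir import SglSamplingParams

# Assumed field names; the output keys ("candidate_count", "max_output_tokens", ...)
# are the ones returned by to_vertexai_kwargs() in the diff above.
params = SglSamplingParams(max_new_tokens=128, temperature=0.7, top_p=0.9)
print(params.to_vertexai_kwargs())
```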
python/sglang/srt/managers/router/manager.py
...
@@ -28,7 +28,7 @@ class RouterManager:
        self.model_client = model_client
        self.recv_reqs = []

-        # Init Some Configs
+        # Init some configs
        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time

    async def loop_for_forward(self):
...
@@ -46,7 +46,7 @@ class RouterManager:
            if has_finished:
                await asyncio.sleep(self.extend_dependency_time)

-            await asyncio.sleep(0.001)
+            await asyncio.sleep(0.0006)

    async def loop_for_recv_requests(self):
        while True:
...
python/sglang/srt/managers/router/model_rpc.py
...
@@ -108,7 +108,7 @@ class ModelRpcServer(rpyc.Service):
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
-        self.stream_interval = 2
+        self.stream_interval = server_args.stream_interval

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(self.tokenizer)
...
python/sglang/srt/server_args.py
...
@@ -17,6 +17,7 @@ class ServerArgs:
    model_mode: List[str] = ()
    schedule_heuristic: str = "lpm"
    random_seed: int = 42
+    stream_interval: int = 2
    disable_log_stats: bool = False
    log_stats_interval: int = 10
    log_level: str = "info"
...
@@ -108,6 +109,12 @@ class ServerArgs:
            default=ServerArgs.random_seed,
            help="Random seed.",
        )
+        parser.add_argument(
+            "--stream-interval",
+            type=int,
+            default=ServerArgs.random_seed,
+            help="The interval in terms of token length for streaming",
+        )
        parser.add_argument(
            "--log-level",
            type=str,
...
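A hedged sketch of how the new field could be exercised programmatically; it assumes `ServerArgs` also has a required `model_path` field (not shown in this diff), so treat the instantiation as illustrative only.

```python
from sglang.srt.server_args import ServerArgs

# model_path is an assumed required field; stream_interval is the new field above.
args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf", stream_interval=4)
assert args.stream_interval == 4
```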
test/lang/test_gemini_backend.py → test/lang/test_vertexai_backend.py
...
@@ -10,10 +10,10 @@ from sglang.test.test_programs import (
    test_stream,
)

-from sglang import Gemini, set_default_backend
+from sglang import VertexAI, set_default_backend


-class TestGeminiBackend(unittest.TestCase):
+class TestVertexAIBackend(unittest.TestCase):
    backend = None
    chat_backend = None
    chat_vision_backend = None
...
@@ -22,9 +22,9 @@ class TestGeminiBackend(unittest.TestCase):
        cls = type(self)
        if cls.backend is None:
-            cls.backend = Gemini("gemini-pro")
-            cls.chat_backend = Gemini("gemini-pro")
-            cls.chat_vision_backend = Gemini("gemini-pro-vision")
+            cls.backend = VertexAI("gemini-pro")
+            cls.chat_backend = VertexAI("gemini-pro")
+            cls.chat_vision_backend = VertexAI("gemini-pro-vision")

    def test_few_shot_qa(self):
        set_default_backend(self.backend)
...
@@ -61,6 +61,6 @@ if __name__ == "__main__":
    # from sglang.global_config import global_config
    # global_config.verbosity = 2

-    # t = TestGeminiBackend()
+    # t = TestVertexAIBackend()
    # t.setUp()
    # t.test_stream()