sglang · Commit fd7c4792 (unverified)

Gemini Backend (#9)

Authored Jan 16, 2024 by shiyi.c_98; committed by GitHub on Jan 16, 2024.
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Parent: c4707f1b
Showing 13 changed files with 311 additions and 2 deletions.
examples/quick_start/gemini_example_complete.py         +26  −0
examples/quick_start/gemini_example_multimodal_chat.py  +19  −0
examples/quick_start/gemini_example_stream.py           +20  −0
examples/quick_start/images/cat.jpeg                    +0   −0
examples/quick_start/images/dog.jpeg                    +0   −0
python/sglang/api.py                                    +1   −0
python/sglang/backend/gemini.py                         +152 −0
python/sglang/lang/interpreter.py                       +1   −0
python/sglang/lang/ir.py                                +10  −0
python/sglang/srt/models/mixtral.py                     +1   −1
python/sglang/test/test_programs.py                     +4   −1
test/lang/test_gemini_backend.py                        +66  −0
test/lang/test_openai_backend.py                        +11  −0
examples/quick_start/gemini_example_complete.py (new file, mode 100644)
```python
from sglang import function, gen, set_default_backend, Gemini


@function
def few_shot_qa(s, question):
    s += (
        """The following are questions with answers.
Q: What is the capital of France?
A: Paris
Q: What is the capital of Germany?
A: Berlin
Q: What is the capital of Italy?
A: Rome
"""
    )
    s += "Q: " + question + "\n"
    s += "A:" + gen("answer", stop="\n", temperature=0)


set_default_backend(Gemini("gemini-pro"))

state = few_shot_qa.run(question="What is the capital of the United States?")
answer = state["answer"].strip().lower()

assert "washington" in answer, f"answer: {state['answer']}"

print(state.text())
```
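The Gemini backend added in this commit initializes Vertex AI from environment variables (see python/sglang/backend/gemini.py below), so GCP credentials must be configured before `Gemini(...)` is constructed. A minimal sketch; the project ID and region here are hypothetical placeholders:

```python
# The backend reads both variables in Gemini.__init__, so set them first.
import os

os.environ["GCP_PROJECT_ID"] = "my-gcp-project"  # hypothetical project ID
os.environ["GCP_LOCATION"] = "us-central1"       # hypothetical region
```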
examples/quick_start/gemini_example_multimodal_chat.py (new file, mode 100644)
```python
from sglang import function, user, assistant, gen, image, set_default_backend, Gemini


@function
def image_qa(s, image_file1, image_file2, question):
    s += user(image(image_file1) + image(image_file2) + question)
    s += assistant(gen("answer_1", max_tokens=256))


set_default_backend(Gemini("gemini-pro-vision"))

state = image_qa.run(
    image_file1="./images/cat.jpeg",
    image_file2="./images/dog.jpeg",
    question="Describe difference of the 2 images in one sentence.",
    stream=True,
)

for out in state.text_iter():
    print(out, end="", flush=True)
```
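The same program should also run without streaming, in which case the generation bound by `gen("answer_1", ...)` is read back by name, as in the completion example above. A sketch (same arguments, stream flag omitted):

```python
# Non-streaming variant of the call above.
state = image_qa.run(
    image_file1="./images/cat.jpeg",
    image_file2="./images/dog.jpeg",
    question="Describe difference of the 2 images in one sentence.",
)
print(state["answer_1"])
```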
examples/quick_start/gemini_example_stream.py (new file, mode 100644)
```python
from sglang import function, user, assistant, gen, set_default_backend, Gemini


@function
def multi_turn_question(s, question_1, question_2):
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))


set_default_backend(Gemini("gemini-pro"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
    stream=True,
)

for out in state.text_iter():
    print(out, end="", flush=True)
```
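Once `text_iter()` is exhausted the program has finished running, so the named generations should presumably be readable from the state as in the non-streaming examples (an assumption about stream semantics, not shown in this commit):

```python
# Assumption: after the stream is drained, "answer_1"/"answer_2" are bound.
print()
print("answer_1:", state["answer_1"])
print("answer_2:", state["answer_2"])
```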
examples/quick_start/images/cat.jpeg (new file, 337 KB)
examples/quick_start/images/dog.jpeg (new file, 407 KB)
python/sglang/api.py (modified)

```diff
@@ -4,6 +4,7 @@ from typing import Callable, List, Optional, Union
 from sglang.backend.anthropic import Anthropic
 from sglang.backend.base_backend import BaseBackend
+from sglang.backend.gemini import Gemini
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
```
python/sglang/backend/gemini.py (new file, mode 100644)
```python
import os
import warnings
from typing import List, Optional, Union

import numpy as np

from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
    import vertexai
    from vertexai.preview.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )
except ImportError as e:
    GenerativeModel = e

GEMINI_MODEL_NAMES = [
    "gemini-pro",
    "gemini-pro-vision",
]


class Gemini(BaseBackend):
    def __init__(self, model_name):
        super().__init__()

        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ["GCP_LOCATION"]
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_gemini_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_gemini_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_gemini_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_gemini_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_gemini_input(self, text, images):
        input = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                input.append(text_seg)
            input.append(Image.from_bytes(image_base64_data))
        text_seg = text_segs.pop(0)
        if text_seg != "":
            input.append(text_seg)
        return input

    def messages_to_gemini_input(self, messages):
        gemini_message = []
        # from openai message format to gemini message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                warnings.warn("Warning: system prompt is not supported in Gemini.")
                gemini_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                gemini_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue

            if msg["role"] == "user":
                gemini_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                gemini_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }

            # images
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    gemini_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            gemini_message.append(gemini_msg)

        return gemini_message
```
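To make the role mapping in `messages_to_gemini_input` concrete, here is an illustrative input/output pair, traced by hand from the code above rather than taken from the commit: a system message is rewritten into a user/model exchange (with a warning), and `assistant` maps to Gemini's `model` role:

```python
openai_messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi!"},
]
# messages_to_gemini_input(openai_messages) would return roughly:
# [
#     {"role": "user", "parts": [{"text": "System prompt: You are concise."}]},
#     {"role": "model", "parts": [{"text": "Understood."}]},
#     {"role": "user", "parts": [{"text": "Hello"}]},
#     {"role": "model", "parts": [{"text": "Hi!"}]},
# ]
```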
python/sglang/lang/interpreter.py (modified)

```diff
@@ -428,6 +428,7 @@ class StreamExecutor:
             self.messages_.append(last_msg)
             self.cur_images = []
         else:
+            # OpenAI chat API format
             self.messages_.append({"role": expr.role, "content": new_text})
         self.cur_role = None
```
python/sglang/lang/ir.py (modified)

```diff
@@ -49,6 +49,16 @@ class SglSamplingParams:
             "presence_penalty": self.presence_penalty,
         }
 
+    def to_gemini_kwargs(self):
+        return {
+            "candidate_count": 1,
+            "max_output_tokens": self.max_new_tokens,
+            "stop_sequences": self.stop,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k if self.top_k > 0 else None,
+        }
+
     def to_anthropic_kwargs(self):
         # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
         return {
```
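For reference, tracing `to_gemini_kwargs` by hand with the parameters used in gemini_example_complete.py gives roughly the following; the field values besides `stop` and `temperature` are assumed, since the defaults are not shown in this diff:

```python
# Illustrative mapping for SglSamplingParams(max_new_tokens=256, stop="\n",
# temperature=0.0, top_p=1.0, top_k=-1) — values assumed, not from the commit:
# .to_gemini_kwargs() ->
# {
#     "candidate_count": 1,
#     "max_output_tokens": 256,
#     "stop_sequences": "\n",
#     "temperature": 0.0,
#     "top_p": 1.0,
#     "top_k": None,   # top_k <= 0 is mapped to None
# }
```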
python/sglang/srt/models/mixtral.py (modified)

```diff
@@ -355,7 +355,7 @@ class MixtralForCausalLM(nn.Module):
         ):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
```
...
python/sglang/test/test_programs.py
View file @
fd7c4792
...
@@ -304,7 +304,10 @@ def test_image_qa():
...
@@ -304,7 +304,10 @@ def test_image_qa():
temperature
=
0
,
temperature
=
0
,
max_new_tokens
=
64
,
max_new_tokens
=
64
,
)
)
assert
"taxi"
in
state
.
messages
()[
-
1
][
"content"
]
assert
(
"taxi"
in
state
.
messages
()[
-
1
][
"content"
]
or
"car"
in
state
.
messages
()[
-
1
][
"content"
]
)
def
test_stream
():
def
test_stream
():
...
...
test/lang/test_gemini_backend.py
0 → 100644
View file @
fd7c4792
```python
import unittest

from sglang.test.test_programs import (
    test_expert_answer,
    test_few_shot_qa,
    test_image_qa,
    test_mt_bench,
    test_parallel_decoding,
    test_parallel_encoding,
    test_stream,
)

from sglang import Gemini, set_default_backend


class TestGeminiBackend(unittest.TestCase):
    backend = None
    chat_backend = None
    chat_vision_backend = None

    def setUp(self):
        cls = type(self)

        if cls.backend is None:
            cls.backend = Gemini("gemini-pro")
            cls.chat_backend = Gemini("gemini-pro")
            cls.chat_vision_backend = Gemini("gemini-pro-vision")

    def test_few_shot_qa(self):
        set_default_backend(self.backend)
        test_few_shot_qa()

    def test_mt_bench(self):
        set_default_backend(self.chat_backend)
        test_mt_bench()

    def test_expert_answer(self):
        set_default_backend(self.backend)
        test_expert_answer()

    def test_parallel_decoding(self):
        set_default_backend(self.backend)
        test_parallel_decoding()

    def test_parallel_encoding(self):
        set_default_backend(self.backend)
        test_parallel_encoding()

    def test_image_qa(self):
        set_default_backend(self.chat_vision_backend)
        test_image_qa()

    def test_stream(self):
        set_default_backend(self.backend)
        test_stream()


if __name__ == "__main__":
    unittest.main(warnings="ignore")

    # from sglang.global_config import global_config
    # global_config.verbosity = 2
    # t = TestGeminiBackend()
    # t.setUp()
    # t.test_stream()
```
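These tests hit live Vertex AI endpoints, so they need the same GCP environment variables as the examples above. To run a single case programmatically, one could use the standard unittest loader (a sketch; the working-directory assumption is mine, not from the commit):

```python
import unittest

# Assumes the working directory is test/lang/ so the module imports by name.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_gemini_backend.TestGeminiBackend.test_stream"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```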
test/lang/test_openai_backend.py (modified)

```diff
@@ -88,4 +88,15 @@ if __name__ == "__main__":
     # global_config.verbosity = 2
     # t = TestOpenAIBackend()
     # t.setUp()
+    # t.test_few_shot_qa()
+    # t.test_mt_bench()
+    # t.test_select()
+    # t.test_decode_int()
     # t.test_decode_json()
+    # t.test_expert_answer()
+    # t.test_tool_use()
+    # t.test_react()
+    # t.test_parallel_decoding()
+    # t.test_parallel_encoding()
+    # t.test_image_qa()
+    # t.test_stream()
```