change / sglang · Commit faba293a (unverified)

Improve gemma and documentations (#278)

Authored Mar 11, 2024 by Lianmin Zheng; committed by GitHub on Mar 11, 2024.
Parent: 89885b31

Showing 9 changed files with 56 additions and 35 deletions (+56, -35):
- README.md (+5, -0)
- python/sglang/backend/runtime_endpoint.py (+23, -9)
- python/sglang/srt/layers/token_attention.py (+1, -1)
- python/sglang/srt/models/gemma.py (+9, -9)
- python/sglang/srt/models/llava.py (+1, -3)
- python/sglang/srt/server.py (+2, -0)
- python/sglang/srt/server_args.py (+6, -6)
- python/sglang/utils.py (+3, -1)
- test/srt/model/bench_llama_low_api.py (+6, -6)
README.md

```diff
@@ -369,8 +369,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - Mistral
 - Mixtral
 - Qwen / Qwen 2
+- Gemma
+  - Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors.
+  - `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32`
 - LLaVA
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 3000`
+- Yi-VL
+  - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
 - AWQ/GPTQ quantization
```
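For readers following the README addition, here is a client-side sketch of how a Gemma server launched with the command above is typically queried through sglang's frontend. The `sgl.function` / `RuntimeEndpoint` surface reflects the project's documented usage of this period and is an assumption, not part of this commit:

```python
# Minimal sketch, assuming a server already started with:
#   python -m sglang.launch_server --model-path google/gemma-7b-it \
#     --port 30000 --attention-reduce-in-fp32
import sglang as sgl

@sgl.function
def qa(s, question):
    # Build the prompt, then ask the runtime to generate a completion.
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_new_tokens=64)

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = qa.run(question="Why run attention reduction in fp32?")
print(state["answer"])
```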
python/sglang/backend/runtime_endpoint.py

```diff
@@ -21,7 +21,9 @@ class RuntimeEndpoint(BaseBackend):
         self.verify = verify

         res = http_request(
-            self.base_url + "/get_model_info", auth_token=self.auth_token, verify=self.verify
+            self.base_url + "/get_model_info",
+            auth_token=self.auth_token,
+            verify=self.verify,
         )
         assert res.status_code == 200
         self.model_info = res.json()
@@ -41,7 +43,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
             auth_token=self.auth_token,
-            verify=self.verify
+            verify=self.verify,
         )
         assert res.status_code == 200
@@ -50,7 +52,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
             auth_token=self.auth_token,
-            verify=self.verify
+            verify=self.verify,
         )
         assert res.status_code == 200
@@ -58,7 +60,10 @@ class RuntimeEndpoint(BaseBackend):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         self._add_images(s, data)
         res = http_request(
-            self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            verify=self.verify,
         )
         assert res.status_code == 200
@@ -90,7 +95,10 @@ class RuntimeEndpoint(BaseBackend):
         self._add_images(s, data)
         res = http_request(
-            self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            verify=self.verify,
         )
         obj = res.json()
         comp = obj["text"]
@@ -129,7 +137,7 @@ class RuntimeEndpoint(BaseBackend):
             json=data,
             stream=True,
             auth_token=self.auth_token,
-            verify=self.verify
+            verify=self.verify,
         )
         pos = 0
@@ -161,7 +169,10 @@ class RuntimeEndpoint(BaseBackend):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         self._add_images(s, data)
         res = http_request(
-            self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            verify=self.verify,
         )
         assert res.status_code == 200
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
@@ -175,7 +186,10 @@ class RuntimeEndpoint(BaseBackend):
         }
         self._add_images(s, data)
         res = http_request(
-            self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify
+            self.base_url + "/generate",
+            json=data,
+            auth_token=self.auth_token,
+            verify=self.verify,
         )
         assert res.status_code == 200
         obj = res.json()
@@ -192,7 +206,7 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
             auth_token=self.auth_token,
-            verify=self.verify
+            verify=self.verify,
         )
         assert res.status_code == 200
```
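Every call site above threads the same two options through `http_request`: `auth_token` (emitted as a bearer-token header) and `verify` (forwarded to the HTTP layer for TLS verification). A sketch of constructing the backend with both, assuming the constructor keywords match the attribute names shown in the diff:

```python
# Sketch only: keyword names are inferred from the self.auth_token /
# self.verify attributes in the diff, not from the full constructor.
from sglang.backend.runtime_endpoint import RuntimeEndpoint

backend = RuntimeEndpoint(
    "https://my-sglang-host:30000",     # hypothetical endpoint
    auth_token="my-secret-token",       # sent with every request
    verify="/etc/ssl/certs/my_ca.pem",  # CA bundle for TLS verification
)
print(backend.model_info)  # fetched from /get_model_info during __init__
```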
python/sglang/srt/layers/token_attention.py

```diff
@@ -4,8 +4,8 @@
 import torch
 import triton
 import triton.language as tl
-from sglang.srt.utils import wrap_kernel_launcher
 from sglang.srt.managers.router.model_runner import global_server_args
+from sglang.srt.utils import wrap_kernel_launcher

 if global_server_args.attention_reduce_in_fp32:
     REDUCE_TRITON_TYPE = tl.float32
```
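`REDUCE_TRITON_TYPE` is what `--attention-reduce-in-fp32` ultimately toggles: the Triton attention kernels accumulate their reductions in `tl.float32` instead of fp16. The idea, illustrated in plain PyTorch rather than the actual Triton kernel:

```python
# Plain-PyTorch illustration of fp32 reduction (not the Triton kernel):
# accumulating q·k and the softmax in fp32 avoids fp16 overflow and the
# precision loss that can degrade Gemma-style models.
import torch

q = torch.randn(8, 128, dtype=torch.float16)
k = torch.randn(8, 128, dtype=torch.float16)

scores_fp16 = (q * k).sum(dim=-1)                  # default fp16 reduce
scores_fp32 = (q.float() * k.float()).sum(dim=-1)  # what the flag enables

probs = torch.softmax(scores_fp32, dim=-1).to(torch.float16)
```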
python/sglang/srt/models/gemma.py

```diff
@@ -7,7 +7,7 @@ import torch
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from torch import nn
-from transformers import GemmaConfig
+from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import GeluAndMul
@@ -136,7 +136,7 @@ class GemmaAttention(nn.Module):
 class GemmaDecoderLayer(nn.Module):
     def __init__(
         self,
-        config: GemmaConfig,
+        config: PretrainedConfig,
         layer_id: int = 0,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
@@ -190,7 +190,7 @@ class GemmaDecoderLayer(nn.Module):
 class GemmaModel(nn.Module):
     def __init__(
         self,
-        config: GemmaConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
@@ -213,12 +213,12 @@ class GemmaModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        if not skip_embed:
+        if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
-            hidden_states = input_ids
+            hidden_states = input_embeds

         # Normalize the embedding by sqrt(hidden_size)
         hidden_states *= self.config.hidden_size**0.5
@@ -262,7 +262,7 @@ class GemmaForCausalLM(nn.Module):
     def __init__(
         self,
-        config: GemmaConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -279,9 +279,9 @@ class GemmaForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        skip_embed: bool = False,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
```
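Two things change here: the annotations drop `GemmaConfig` for the generic `PretrainedConfig` (so the model imports on transformers versions that predate `GemmaConfig`), and the boolean `skip_embed` becomes an optional `input_embeds` tensor, letting a caller such as the LLaVA wrapper inject precomputed embeddings instead of token ids. A small sketch of the resulting calling convention, with construction details elided:

```python
# Sketch of the new forward() contract (names from the diff; the model
# and multimodal projector here are stand-ins, not real construction).
import torch

hidden_size = 3072  # Gemma-7B's hidden size, for illustration

# Text-only path: pass token ids, let the model embed them.
#   hidden = model(input_ids, positions, input_metadata)

# Multimodal path: a wrapper passes precomputed embeddings instead.
#   embeds = projector(image_features)  # hypothetical projector
#   hidden = model(input_ids, positions, input_metadata, input_embeds=embeds)

# Either way the model then scales the embeddings, per the diff:
hidden_states = torch.randn(4, hidden_size)
hidden_states *= hidden_size ** 0.5  # "Normalize the embedding by sqrt(hidden_size)"
```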
python/sglang/srt/models/llava.py

```diff
@@ -233,9 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
         elif input_metadata.forward_mode == ForwardMode.DECODE:
-            return self.language_model(
-                input_ids, positions, input_metadata
-            )
+            return self.language_model(input_ids, positions, input_metadata)

     def load_weights(
         self,
```
python/sglang/srt/server.py

```diff
@@ -550,6 +550,7 @@ class Runtime:
         tp_size: int = 1,
         model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
+        attention_reduce_in_fp32: bool = False,
         random_seed: int = 42,
         log_level: str = "error",
         port: Optional[int] = None,
@@ -572,6 +573,7 @@ class Runtime:
             tp_size=tp_size,
             model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
+            attention_reduce_in_fp32=attention_reduce_in_fp32,
             random_seed=random_seed,
             log_level=log_level,
         )
```
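With this, the fp32 reduction can also be enabled programmatically rather than only via the CLI flag. A sketch, where only `attention_reduce_in_fp32` is confirmed by this diff and the other names follow `Runtime`'s usual surface:

```python
# Sketch: launching an in-process server with the new keyword. model_path
# and shutdown() are assumed from Runtime's broader API, not this diff.
from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="google/gemma-7b-it",
    attention_reduce_in_fp32=True,  # forwarded into ServerArgs below
)
# ... issue requests against the runtime's endpoint ...
runtime.shutdown()
```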
python/sglang/srt/server_args.py

```diff
@@ -21,6 +21,7 @@ class ServerArgs:
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
+    attention_reduce_in_fp32: bool = False
     random_seed: int = 42
     stream_interval: int = 8
     disable_log_stats: bool = False
@@ -28,7 +29,6 @@ class ServerArgs:
     log_level: str = "info"
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
-    attention_reduce_in_fp32: bool = False

     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -157,6 +157,11 @@ class ServerArgs:
             default=ServerArgs.random_seed,
             help="Random seed.",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+        )
         parser.add_argument(
             "--stream-interval",
             type=int,
@@ -190,11 +195,6 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
-        parser.add_argument(
-            "--attention-reduce-in-fp32",
-            action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
```
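The dataclass field and the `--attention-reduce-in-fp32` flag are not new here; they are moved earlier, next to the scheduling options and after `--random-seed`, so declaration order matches the CLI help. A round-trip sketch, where `add_cli_args` is assumed to be the registration counterpart of the `from_cli_args` shown above:

```python
# Sketch: ServerArgs round-trip through argparse. add_cli_args is an
# assumption (the diff only shows from_cli_args and add_argument calls).
import argparse
from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model-path", "google/gemma-7b-it", "--attention-reduce-in-fp32"]
)
server_args = ServerArgs.from_cli_args(args)
assert server_args.attention_reduce_in_fp32 is True
```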
python/sglang/utils.py

```diff
@@ -97,7 +97,9 @@ def http_request(url, json=None, stream=False, auth_token=None, verify=None):
             "Content-Type": "application/json",
             "Authentication": f"Bearer {auth_token}",
         }
-        return requests.post(url, json=json, stream=True, headers=headers, verify=verify)
+        return requests.post(
+            url, json=json, stream=True, headers=headers, verify=verify
+        )
     else:
         req = urllib.request.Request(url)
         req.add_header("Content-Type", "application/json; charset=utf-8")
```
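The hunk header gives `http_request`'s full signature. A usage sketch; `status_code` and `json()` on the return value are grounded by the `runtime_endpoint.py` call sites above:

```python
# Usage sketch for http_request (signature shown in the hunk header).
from sglang.utils import http_request

res = http_request(
    "http://localhost:30000/get_model_info",  # hypothetical local server
    auth_token="my-secret-token",  # sent as a bearer-token header
    verify=None,                   # or a CA-bundle path for HTTPS
)
assert res.status_code == 200
print(res.json())
```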
test/srt/model/bench_llama_low_api.py

```diff
@@ -66,9 +66,9 @@ class BenchBatch:
             p_idx = prefix_req_idx[i // fork_num].item()
             n_idx = self.req_pool_indices[i].item()
             req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
-            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
-                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
-            )
+            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = self.out_cache_loc[
+                i * extend_len : (i + 1) * extend_len
+            ]

     def update_decode(self, predict_ids, batch_size):
         assert predict_ids.shape[0] == batch_size
@@ -81,9 +81,9 @@ class BenchBatch:
             self.out_cache_cont_start,
             self.out_cache_cont_end,
         ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
+        self.req_to_token_pool.req_to_token[
+            self.req_pool_indices, self.seq_lens
+        ] = self.out_cache_loc
         self.seq_lens.add_(1)
```
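Both hunks are formatting-only, but the table they touch is worth spelling out: `req_to_token` maps (request slot, token position) to a KV-cache location, and the benchmark fills it during prefix forking and decode. A toy sketch of those two assignments; shapes and values are illustrative:

```python
# Toy illustration of the req_to_token bookkeeping exercised above.
import torch

num_reqs, max_len = 4, 16
req_to_token = torch.zeros(num_reqs, max_len, dtype=torch.int64)

# Prefill with forking: reuse the parent's prefix mapping, then point the
# extended positions at freshly allocated cache slots.
prefix_len, extend_len = 4, 3
req_to_token[1, :prefix_len] = req_to_token[0, :prefix_len]
req_to_token[1, prefix_len : prefix_len + extend_len] = torch.tensor([100, 101, 102])

# Decode: each running request gets one new slot at its current length.
req_pool_indices = torch.tensor([0, 1])
seq_lens = torch.tensor([4, 7])
out_cache_loc = torch.tensor([200, 201])
req_to_token[req_pool_indices, seq_lens] = out_cache_loc
```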