Commit 89885b31 (unverified): Gemma Support (#256)
Authored Mar 11, 2024 by Liangsheng Yin; committed via GitHub on Mar 11, 2024.
Parent: 64fe3115

Showing 10 changed files with 428 additions and 55 deletions.
python/sglang/lang/chat_template.py                         +32   -4
python/sglang/srt/layers/context_flashattention_nopad.py     +1   -1
python/sglang/srt/layers/extend_attention.py                 +7   -6
python/sglang/srt/layers/radix_attention.py                  +2   -8
python/sglang/srt/layers/token_attention.py                 +12   -4
python/sglang/srt/managers/router/model_rpc.py              +13  -12
python/sglang/srt/managers/router/model_runner.py            +8  -15
python/sglang/srt/model_config.py                            +7   -5
python/sglang/srt/models/gemma.py                          +340   -0
python/sglang/srt/server_args.py                             +6   -0
python/sglang/lang/chat_template.py

@@ -17,15 +17,23 @@ class ChatTemplate:
     image_token: str = "<image>"
     style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

-    def get_prefix_and_suffix(self, role: str, hist_messages: List[Dict]) -> Tuple[str, str]:
+    def get_prefix_and_suffix(
+        self, role: str, hist_messages: List[Dict]
+    ) -> Tuple[str, str]:
         prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))

         if self.style == ChatTemplateStyle.LLAMA2:
             if role == "system" and not hist_messages:
                 user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
-                system_prefix, system_suffix = self.role_prefix_and_suffix.get("system", ("", ""))
+                system_prefix, system_suffix = self.role_prefix_and_suffix.get(
+                    "system", ("", "")
+                )
                 return (user_prefix + system_prefix, system_suffix)
-            elif role == "user" and len(hist_messages) == 1 and hist_messages[0]["content"] is not None:
+            elif (
+                role == "user"
+                and len(hist_messages) == 1
+                and hist_messages[0]["content"] is not None
+            ):
                 return ("", suffix)

         return prefix, suffix

@@ -171,6 +179,19 @@ register_chat_template(
     )
 )

+register_chat_template(
+    ChatTemplate(
+        name="gemma-it",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
+            "assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+    )
+)

 @register_chat_template_matching_function
 def match_vicuna(model_path: str):

@@ -211,6 +232,13 @@ def match_chat_yi(model_path: str):
     return get_chat_template("yi")

+@register_chat_template_matching_function
+def match_gemma_it(model_path: str):
+    model_path = model_path.lower()
+    if "gemma" in model_path and "it" in model_path:
+        return get_chat_template("gemma-it")

 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
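To make the effect of the new template concrete, here is a small sketch (illustrative, not part of the commit) that renders one user turn followed by the assistant prefix, using only get_chat_template and get_prefix_and_suffix from the file above:

# Illustrative only -- not part of the commit.
from sglang.lang.chat_template import get_chat_template

template = get_chat_template("gemma-it")
user_prefix, user_suffix = template.get_prefix_and_suffix("user", [])
assistant_prefix, _ = template.get_prefix_and_suffix("assistant", [])

prompt = user_prefix + "Why is the sky blue?" + user_suffix + assistant_prefix
print(prompt)
# <start_of_turn>user
# Why is the sky blue?
# <end_of_turn>
# <start_of_turn>model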
python/sglang/srt/layers/context_flashattention_nopad.py

@@ -129,7 +129,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
-    assert Lk in {16, 32, 64, 128}
+    assert Lk in {16, 32, 64, 128, 256}

     sm_scale = 1.0 / (Lq**0.5)
     batch, head = b_seq_len.shape[0], q.shape[1]
python/sglang/srt/layers/extend_attention.py

@@ -181,19 +181,20 @@ def extend_attention_fwd(
         k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    if CUDA_CAPABILITY[0] >= 8:
-        BLOCK_M, BLOCK_N = 128, 128
-    else:
-        BLOCK_M, BLOCK_N = 64, 64
-
     Lq, Lk, Lv, Lo = (
         q_extend.shape[-1],
         k_extend.shape[-1],
         v_extend.shape[-1],
         o_extend.shape[-1],
     )

     assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128}
+    assert Lq in {16, 32, 64, 128, 256}
+
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
+    else:
+        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

     sm_scale = 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
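For context (illustrative reasoning, not stated in the commit): the tile sizes now depend on the head dimension, presumably because Gemma's head_dim of 256 roughly doubles the per-block memory footprint and would otherwise exceed shared-memory/register budgets at (128, 128). A back-of-the-envelope sketch of the fp16 tile footprint:

# Illustrative arithmetic only -- not part of the commit.
def tile_bytes(block_m, block_n, head_dim, elem_size=2):  # elem_size=2 for fp16
    q_tile = block_m * head_dim * elem_size
    k_tile = block_n * head_dim * elem_size
    v_tile = block_n * head_dim * elem_size
    scores = block_m * block_n * elem_size
    return q_tile + k_tile + v_tile + scores

print(tile_bytes(128, 128, 128))  # 131072 bytes (128 KiB)
print(tile_bytes(128, 128, 256))  # 229376 bytes (224 KiB) -> hence the fallback to (64, 64)
print(tile_bytes(64, 64, 256))    # 106496 bytes (104 KiB)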
python/sglang/srt/layers/radix_attention.py

-from typing import List
-
 import torch
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)

 class RadixAttention(nn.Module):

@@ -21,9 +15,9 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id

-        from sglang.srt.managers.router.model_runner import global_model_mode
+        from sglang.srt.managers.router.model_runner import global_server_args

-        self.use_flashinfer = "flashinfer" in global_model_mode
+        self.use_flashinfer = "flashinfer" in global_server_args.model_mode

         if self.use_flashinfer:
             self.prefill_forward = self.prefill_forward_flashinfer
python/sglang/srt/layers/token_attention.py

@@ -5,6 +5,14 @@ import torch
 import triton
 import triton.language as tl
 from sglang.srt.utils import wrap_kernel_launcher
+from sglang.srt.managers.router.model_runner import global_server_args
+
+if global_server_args.attention_reduce_in_fp32:
+    REDUCE_TRITON_TYPE = tl.float32
+    REDUCE_TORCH_TYPE = torch.float32
+else:
+    REDUCE_TRITON_TYPE = tl.float16
+    REDUCE_TORCH_TYPE = torch.float16

 @triton.jit

@@ -49,7 +57,7 @@ def _fwd_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)

     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark)
+        q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,

@@ -65,7 +73,7 @@ def _fwd_kernel_stage1(
             K_Buffer + offs_buf_k,
             mask=offs_n_new[:, None] < cur_batch_end_index,
             other=0.0,
-        )
+        ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)

@@ -161,7 +169,7 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128}
+    assert Lk in {16, 32, 64, 128, 256}
     sm_scale = 1.0 / (Lk**0.5)

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -299,7 +307,7 @@ def token_attention_fwd(
 ):
     if att_m is None:
         att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda"
+            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )

     _token_att_m_fwd(
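As background on the new attention_reduce_in_fp32 switch (illustrative, not from the commit): fp16 has only a 10-bit mantissa, so accumulating many attention terms in fp16 can silently drop small contributions; with the flag set, the Triton kernels load and reduce in fp32 instead. A tiny demonstration of the underlying rounding:

# Illustrative only -- not part of the commit.
import torch

a = torch.tensor(2048.0, dtype=torch.float16)
b = torch.tensor(1.0, dtype=torch.float16)
print(a + b)                  # tensor(2048., dtype=torch.float16) -- the +1 is rounded away
print(a.float() + b.float())  # tensor(2049.) -- fp32 keeps it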
python/sglang/srt/managers/router/model_rpc.py

@@ -57,17 +57,19 @@ class ModelRpcServer(rpyc.Service):
         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path,
-            server_args.trust_remote_code,
-            context_length=server_args.context_length
+            server_args.model_path,
+            server_args.trust_remote_code,
+            context_length=server_args.context_length,
         )
         self.model_runner = ModelRunner(
-            self.model_config,
-            server_args.mem_fraction_static,
-            tp_rank,
-            server_args.tp_size,
-            port_args.nccl_port,
-            server_args.load_format,
-            server_args.trust_remote_code,
-            server_args.model_mode,
+            model_config=self.model_config,
+            mem_fraction_static=server_args.mem_fraction_static,
+            tp_rank=tp_rank,
+            tp_size=server_args.tp_size,
+            nccl_port=port_args.nccl_port,
+            server_args=server_args,
+            load_format=server_args.load_format,
+            trust_remote_code=server_args.trust_remote_code,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(

@@ -435,7 +437,7 @@ class ModelRpcServer(rpyc.Service):
                 # If logprob_start_len > 0, then first logprob_start_len prompt tokens
                 # will be ignored.
                 prompt_token_len = len(req.logprob)
-                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
+                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
                 token_logprobs = req.logprob + [last_logprobs[i]]
                 req.token_logprob = list(zip(token_ids, token_logprobs))
                 if req.logprob_start_len == 0:

@@ -553,8 +555,7 @@ class ModelRpcServer(rpyc.Service):
                 "completion_tokens": len(req.input_ids)
                 + len(req.output_ids)
                 - req.prompt_tokens,
-                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward
+                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
             }
             if req.return_logprob:
                 meta_info["prompt_logprob"] = req.logprob
python/sglang/srt/managers/router/model_runner.py

@@ -3,7 +3,6 @@ import logging
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-from typing import List

 import numpy as np
 import torch

@@ -23,8 +22,8 @@ QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
 logger = logging.getLogger("model_runner")

-# for model_mode
-global_model_mode: List[str] = []
+# for server args in model endpoints
+global_server_args = None

 @lru_cache()

@@ -81,7 +80,6 @@ class InputMetadata:
     return_logprob: bool = False

     # for flashinfer
-    use_flashinfer: bool = False
     qo_indptr: torch.Tensor = None
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None

@@ -224,8 +222,8 @@ class InputMetadata:
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()

-        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
-        if ret.use_flashinfer:
+        if "flashinfer" in global_server_args.model_mode:
             ret.init_flashinfer_args(tp_size)

         return ret

@@ -239,9 +236,9 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
+        server_args,
         load_format="auto",
         trust_remote_code=True,
-        model_mode: List[str] = (),
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static

@@ -250,10 +247,9 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.load_format = load_format
         self.trust_remote_code = trust_remote_code
-        self.model_mode = model_mode

-        global global_model_mode
-        global_model_mode = model_mode
+        global global_server_args
+        global_server_args = server_args

         # Init torch distributed
         torch.cuda.set_device(self.tp_rank)

@@ -319,9 +315,7 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
         ) * (1 << 30)
-        head_dim = (
-            self.model_config.hidden_size // self.model_config.num_attention_heads
-        )
+        head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (

@@ -346,8 +340,7 @@ class ModelRunner:
             self.max_total_num_token,
             dtype=torch.float16,
             head_num=self.model_config.num_key_value_heads // self.tp_size,
-            head_dim=self.model_config.hidden_size
-            // self.model_config.num_attention_heads,
+            head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
python/sglang/srt/model_config.py

 import os
-from typing import Optional, Union
+from typing import Optional

 import torch
 from sglang.srt.hf_transformers_utils import get_config, get_context_length

@@ -17,14 +15,18 @@ class ModelConfig:
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
         if context_length is not None:
             self.context_len = context_length
         else:
             self.context_len = get_context_length(self.hf_config)

         # Unify the config keys for hf_config
-        self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
+        self.head_dim = getattr(
+            self.hf_config,
+            "head_dim",
+            self.hf_config.hidden_size // self.hf_config.num_attention_heads,
+        )
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
         if self.num_key_value_heads is None:
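The getattr fallback matters for Gemma: its HF config carries an explicit head_dim that is not hidden_size // num_attention_heads, which is why both ModelConfig and the KV-cache sizing in model_runner.py now read model_config.head_dim. A small sketch (illustrative, not part of the commit; the concrete numbers come from the google/gemma-7b config and are an assumption here):

# Illustrative only -- not part of the commit.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/gemma-7b-it")
derived = cfg.hidden_size // cfg.num_attention_heads  # 3072 // 16 = 192 (assumed values)
explicit = getattr(cfg, "head_dim", derived)           # 256 -- what the new code uses
print(derived, explicit)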
python/sglang/srt/models/gemma.py (new file, mode 100644)

# Adapted from:
# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
"""Inference-only Gemma model compatible with HuggingFace weights."""
from typing import Optional, Tuple

import torch
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from torch import nn
from transformers import GemmaConfig
from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    LinearMethodBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)


class GemmaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False, linear_method=linear_method
        )
        self.act_fn = GeluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class GemmaAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        layer_id: int = 0,
        max_position_embeddings: int = 8192,
        rope_theta: float = 10000,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class GemmaDecoderLayer(nn.Module):
    def __init__(
        self,
        config: GemmaConfig,
        layer_id: int = 0,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            layer_id=layer_id,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            linear_method=linear_method,
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class GemmaModel(nn.Module):
    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                GemmaDecoderLayer(config, i, linear_method)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        skip_embed: bool = False,
    ) -> torch.Tensor:
        if not skip_embed:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_ids

        # Normalize the embedding by sqrt(hidden_size)
        hidden_states *= self.config.hidden_size**0.5

        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class GemmaForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = GemmaModel(config, linear_method)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        skip_embed: bool = False,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                if "norm.weight" in name:
                    loaded_weight += 1.0
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}"
            )


EntryClass = GemmaForCausalLM
python/sglang/srt/server_args.py

@@ -28,6 +28,7 @@ class ServerArgs:
     log_level: str = "info"
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
+    attention_reduce_in_fp32: bool = False

     def __post_init__(self):
         if self.tokenizer_path is None:

@@ -189,6 +190,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
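The new option is exposed both as the dataclass field above and as the --attention-reduce-in-fp32 CLI switch; token_attention.py reads it through global_server_args to pick the fp32 reduce types. A hedged sketch of setting it programmatically (ServerArgs fields other than those visible in this diff are assumptions):

# Illustrative only -- not part of the commit.
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="google/gemma-7b-it",  # assumed to be the leading ServerArgs field
    attention_reduce_in_fp32=True,    # new in this commit
)
print(server_args.attention_reduce_in_fp32)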