Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0909bb0d
"vscode:/vscode.git/clone" did not exist on "7aa6af1138b206bec10ab3af23a365c0f573b67d"
Unverified
Commit
0909bb0d
authored
Aug 13, 2024
by
Ying Sheng
Committed by
GitHub
Aug 13, 2024
Browse files
[Feat] Add window attention for gemma-2 (#1056)
parent
ad3e4f16
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
319 additions
and
126 deletions
+319
-126
python/sglang/bench_latency.py
python/sglang/bench_latency.py
+1
-1
python/sglang/srt/layers/radix_attention.py
python/sglang/srt/layers/radix_attention.py
+40
-19
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+145
-58
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+69
-17
python/sglang/srt/models/gemma2.py
python/sglang/srt/models/gemma2.py
+11
-5
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+12
-0
python/sglang/test/long_prompt
python/sglang/test/long_prompt
+1
-0
python/sglang/test/runners.py
python/sglang/test/runners.py
+16
-10
scripts/playground/reference_hf.py
scripts/playground/reference_hf.py
+4
-4
test/srt/models/test_embedding_models.py
test/srt/models/test_embedding_models.py
+6
-4
test/srt/models/test_generation_models.py
test/srt/models/test_generation_models.py
+14
-8
No files found.
python/sglang/bench_latency.py
View file @
0909bb0d
...
...
@@ -64,7 +64,7 @@ class BenchArgs:
run_name
:
str
=
"before"
batch_size
:
Tuple
[
int
]
=
(
1
,)
input_len
:
Tuple
[
int
]
=
(
1024
,)
output_len
:
Tuple
[
int
]
=
(
4
,)
output_len
:
Tuple
[
int
]
=
(
16
,)
result_filename
:
str
=
""
correctness_test
:
bool
=
False
# This is only used for correctness test
...
...
python/sglang/srt/layers/radix_attention.py
View file @
0909bb0d
...
...
@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
scaling
:
float
,
num_kv_heads
:
int
,
layer_id
:
int
,
sliding_window_size
:
int
=
-
1
,
logit_cap
:
int
=
-
1
,
v_head_dim
:
int
=
-
1
,
):
...
...
@@ -46,6 +47,7 @@ class RadixAttention(nn.Module):
self
.
v_head_dim
=
v_head_dim
if
v_head_dim
!=
-
1
else
head_dim
self
.
scaling
=
scaling
self
.
layer_id
=
layer_id
self
.
sliding_window_size
=
sliding_window_size
if
(
not
global_server_args_dict
.
get
(
"disable_flashinfer"
,
False
)
...
...
@@ -113,39 +115,51 @@ class RadixAttention(nn.Module):
return
o
def
extend_forward_flashinfer
(
self
,
q
,
k
,
v
,
input_metadata
:
InputMetadata
):
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
prefill_wrapper_ragged
=
input_metadata
.
flashinfer_prefill_wrapper_ragged
prefill_wrapper_paged
=
input_metadata
.
flashinfer_prefill_wrapper_paged
if
self
.
sliding_window_size
!=
-
1
:
prefill_wrapper_ragged
=
prefill_wrapper_ragged
[
0
]
prefill_wrapper_paged
=
prefill_wrapper_paged
[
0
]
else
:
if
isinstance
(
prefill_wrapper_ragged
,
list
):
prefill_wrapper_ragged
=
prefill_wrapper_ragged
[
1
]
if
isinstance
(
prefill_wrapper_paged
,
list
):
prefill_wrapper_paged
=
prefill_wrapper_paged
[
1
]
if
not
input_metadata
.
flashinfer_use_ragged
:
self
.
store_kv_cache
(
k
,
v
,
input_metadata
)
o
=
input_metadata
.
flashinfer_
prefill_wrapper_paged
.
forward
(
o
=
prefill_wrapper_paged
.
forward
(
q
.
contiguous
().
view
(
-
1
,
self
.
tp_q_head_num
,
self
.
head_dim
),
input_metadata
.
token_to_kv_pool
.
get_kv_buffer
(
self
.
layer_id
),
causal
=
True
,
sm_scale
=
self
.
scaling
,
window_left
=
self
.
sliding_window_size
,
logits_soft_cap
=
self
.
logit_cap
,
)
else
:
o1
,
s1
=
(
input_metadata
.
flashinfer_prefill_wrapper_ragged
.
forward_return_lse
(
q
.
contiguous
().
view
(
-
1
,
self
.
tp_q_head_num
,
self
.
head_dim
),
k
.
contiguous
().
view
(
-
1
,
self
.
tp_k_head_num
,
self
.
head_dim
),
v
.
contiguous
().
view
(
-
1
,
self
.
tp_v_head_num
,
self
.
head_dim
),
causal
=
True
,
sm_scale
=
self
.
scaling
,
logits_soft_cap
=
self
.
logit_cap
,
)
o1
,
s1
=
prefill_wrapper_ragged
.
forward_return_lse
(
q
.
contiguous
().
view
(
-
1
,
self
.
tp_q_head_num
,
self
.
head_dim
),
k
.
contiguous
().
view
(
-
1
,
self
.
tp_k_head_num
,
self
.
head_dim
),
v
.
contiguous
().
view
(
-
1
,
self
.
tp_v_head_num
,
self
.
head_dim
),
causal
=
True
,
sm_scale
=
self
.
scaling
,
window_left
=
self
.
sliding_window_size
,
logits_soft_cap
=
self
.
logit_cap
,
)
if
input_metadata
.
extend_no_prefix
:
o
=
o1
else
:
o2
,
s2
=
(
input_metadata
.
flashinfer_prefill_wrapper_paged
.
forward_return_lse
(
q
.
contiguous
().
view
(
-
1
,
self
.
tp_q_head_num
,
self
.
head_dim
),
input_metadata
.
token_to_kv_pool
.
get_kv_buffer
(
self
.
layer_id
),
causal
=
False
,
sm_scale
=
self
.
scaling
,
logits_soft_cap
=
self
.
logit_cap
,
)
# TODO window attention + radix attention will come up in next PR
assert
self
.
sliding_window_size
==
-
1
o2
,
s2
=
prefill_wrapper_paged
.
forward_return_lse
(
q
.
contiguous
().
view
(
-
1
,
self
.
tp_q_head_num
,
self
.
head_dim
),
input_metadata
.
token_to_kv_pool
.
get_kv_buffer
(
self
.
layer_id
)
,
causal
=
False
,
sm_scale
=
self
.
scaling
,
logits_soft_cap
=
self
.
logit_cap
,
)
o
,
_
=
merge_state
(
o1
,
s1
,
o2
,
s2
)
...
...
@@ -158,9 +172,16 @@ class RadixAttention(nn.Module):
return
o
.
view
(
-
1
,
self
.
tp_q_head_num
*
self
.
head_dim
)
def
decode_forward_flashinfer
(
self
,
q
,
k
,
v
,
input_metadata
:
InputMetadata
):
decode_wrapper
=
input_metadata
.
flashinfer_decode_wrapper
if
self
.
sliding_window_size
!=
-
1
:
decode_wrapper
=
decode_wrapper
[
0
]
else
:
if
isinstance
(
decode_wrapper
,
list
):
decode_wrapper
=
decode_wrapper
[
1
]
self
.
store_kv_cache
(
k
,
v
,
input_metadata
)
o
=
input_metadata
.
flashinfer_
decode_wrapper
.
forward
(
o
=
decode_wrapper
.
forward
(
q
.
contiguous
().
view
(
-
1
,
self
.
tp_q_head_num
,
self
.
head_dim
),
input_metadata
.
token_to_kv_pool
.
get_kv_buffer
(
self
.
layer_id
),
sm_scale
=
self
.
scaling
,
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
0909bb0d
...
...
@@ -16,7 +16,7 @@ limitations under the License.
"""ModelRunner runs the forward passes of the models."""
from
dataclasses
import
dataclass
from
enum
import
IntEnum
,
auto
from
typing
import
TYPE_CHECKING
,
List
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
numpy
as
np
import
torch
...
...
@@ -154,6 +154,7 @@ class InputMetadata:
model_runner
:
"ModelRunner"
,
batch
:
ScheduleBatch
,
forward_mode
:
ForwardMode
,
sliding_window_size
:
Optional
[
int
]
=
None
,
):
ret
=
cls
(
forward_mode
=
forward_mode
,
...
...
@@ -197,7 +198,7 @@ class InputMetadata:
):
flashinfer_use_ragged
=
True
ret
.
init_flashinfer_handlers
(
model_runner
,
prefix_lens
,
flashinfer_use_ragged
model_runner
,
prefix_lens
,
flashinfer_use_ragged
,
sliding_window_size
)
return
ret
...
...
@@ -216,7 +217,11 @@ class InputMetadata:
self
.
triton_max_extend_len
=
int
(
torch
.
max
(
extend_seq_lens
))
def
init_flashinfer_handlers
(
self
,
model_runner
,
prefix_lens
,
flashinfer_use_ragged
self
,
model_runner
,
prefix_lens
,
flashinfer_use_ragged
,
sliding_window_size
=
None
,
):
update_flashinfer_indices
(
self
.
forward_mode
,
...
...
@@ -225,6 +230,7 @@ class InputMetadata:
self
.
seq_lens
,
prefix_lens
,
flashinfer_use_ragged
=
flashinfer_use_ragged
,
sliding_window_size
=
sliding_window_size
,
)
(
...
...
@@ -248,6 +254,7 @@ def update_flashinfer_indices(
prefix_lens
,
flashinfer_decode_wrapper
=
None
,
flashinfer_use_ragged
=
False
,
sliding_window_size
=
None
,
):
"""Init auxiliary variables for FlashInfer attention backend."""
num_qo_heads
=
model_runner
.
model_config
.
num_attention_heads
//
model_runner
.
tp_size
...
...
@@ -255,65 +262,145 @@ def update_flashinfer_indices(
head_dim
=
model_runner
.
model_config
.
head_dim
batch_size
=
len
(
req_pool_indices
)
if
flashinfer_use_ragged
:
paged_kernel_lens
=
prefix_lens
else
:
paged_kernel_lens
=
seq_lens
kv_indptr
=
torch
.
zeros
((
batch_size
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_indptr
[
1
:]
=
torch
.
cumsum
(
paged_kernel_lens
,
dim
=
0
)
req_pool_indices_cpu
=
req_pool_indices
.
cpu
().
numpy
()
paged_kernel_lens_cpu
=
paged_kernel_lens
.
cpu
().
numpy
()
kv_indices
=
torch
.
cat
(
[
model_runner
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
i
],
:
paged_kernel_lens_cpu
[
i
]
]
for
i
in
range
(
batch_size
)
],
dim
=
0
,
).
contiguous
()
kv_last_page_len
=
torch
.
ones
((
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
if
forward_mode
==
ForwardMode
.
DECODE
:
# CUDA graph uses different flashinfer_decode_wrapper
if
flashinfer_decode_wrapper
is
None
:
flashinfer_decode_wrapper
=
model_runner
.
flashinfer_decode_wrapper
flashinfer_decode_wrapper
.
end_forward
()
flashinfer_decode_wrapper
.
begin_forward
(
kv_indptr
,
kv_indices
,
kv_last_page_len
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
1
,
)
else
:
# extend part
qo_indptr
=
torch
.
zeros
((
batch_size
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
qo_indptr
[
1
:]
=
torch
.
cumsum
(
seq_lens
-
prefix_lens
,
dim
=
0
)
if
sliding_window_size
is
None
:
if
flashinfer_use_ragged
:
model_runner
.
flashinfer_prefill_wrapper_ragged
.
end_forward
()
model_runner
.
flashinfer_prefill_wrapper_ragged
.
begin_forward
(
qo_indptr
,
paged_kernel_lens
=
prefix_lens
else
:
paged_kernel_lens
=
seq_lens
kv_indptr
=
torch
.
zeros
((
batch_size
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_indptr
[
1
:]
=
torch
.
cumsum
(
paged_kernel_lens
,
dim
=
0
)
req_pool_indices_cpu
=
req_pool_indices
.
cpu
().
numpy
()
paged_kernel_lens_cpu
=
paged_kernel_lens
.
cpu
().
numpy
()
kv_indices
=
torch
.
cat
(
[
model_runner
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
i
],
:
paged_kernel_lens_cpu
[
i
]
]
for
i
in
range
(
batch_size
)
],
dim
=
0
,
).
contiguous
()
kv_last_page_len
=
torch
.
ones
((
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
if
forward_mode
==
ForwardMode
.
DECODE
:
# CUDA graph uses different flashinfer_decode_wrapper
if
flashinfer_decode_wrapper
is
None
:
flashinfer_decode_wrapper
=
model_runner
.
flashinfer_decode_wrapper
flashinfer_decode_wrapper
.
end_forward
()
flashinfer_decode_wrapper
.
begin_forward
(
kv_indptr
,
kv_indices
,
kv_last_page_len
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
1
,
)
else
:
# extend part
qo_indptr
=
torch
.
zeros
((
batch_size
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
qo_indptr
[
1
:]
=
torch
.
cumsum
(
seq_lens
-
prefix_lens
,
dim
=
0
)
if
flashinfer_use_ragged
:
model_runner
.
flashinfer_prefill_wrapper_ragged
.
end_forward
()
model_runner
.
flashinfer_prefill_wrapper_ragged
.
begin_forward
(
qo_indptr
,
qo_indptr
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
)
# cached part
model_runner
.
flashinfer_prefill_wrapper_paged
.
end_forward
()
model_runner
.
flashinfer_prefill_wrapper_paged
.
begin_forward
(
qo_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_len
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
1
,
)
else
:
kv_last_page_len
=
torch
.
ones
((
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
for
wrapper_id
in
range
(
2
):
if
flashinfer_use_ragged
:
paged_kernel_lens
=
prefix_lens
else
:
paged_kernel_lens
=
seq_lens
# cached part
model_runner
.
flashinfer_prefill_wrapper_paged
.
end_forward
()
model_runner
.
flashinfer_prefill_wrapper_paged
.
begin_forward
(
qo_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_len
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
1
,
)
if
wrapper_id
==
0
and
forward_mode
==
ForwardMode
.
DECODE
:
paged_kernel_lens
=
torch
.
minimum
(
paged_kernel_lens
,
torch
.
tensor
(
sliding_window_size
)
)
kv_start_idx
=
seq_lens
-
paged_kernel_lens
else
:
kv_start_idx
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_indptr
=
torch
.
zeros
((
batch_size
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_indptr
[
1
:]
=
torch
.
cumsum
(
paged_kernel_lens
,
dim
=
0
)
req_pool_indices_cpu
=
req_pool_indices
.
cpu
().
numpy
()
paged_kernel_lens_cpu
=
paged_kernel_lens
.
cpu
().
numpy
()
kv_indices
=
torch
.
cat
(
[
model_runner
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
i
],
kv_start_idx
[
i
]
:
kv_start_idx
[
i
]
+
paged_kernel_lens_cpu
[
i
],
]
for
i
in
range
(
batch_size
)
],
dim
=
0
,
).
contiguous
()
if
forward_mode
==
ForwardMode
.
DECODE
:
# CUDA graph uses different flashinfer_decode_wrapper
if
flashinfer_decode_wrapper
is
None
:
flashinfer_decode_wrapper
=
model_runner
.
flashinfer_decode_wrapper
flashinfer_decode_wrapper
[
wrapper_id
].
end_forward
()
flashinfer_decode_wrapper
[
wrapper_id
].
begin_forward
(
kv_indptr
,
kv_indices
,
kv_last_page_len
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
1
,
)
else
:
# extend part
qo_indptr
=
torch
.
zeros
(
(
batch_size
+
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
qo_indptr
[
1
:]
=
torch
.
cumsum
(
seq_lens
-
prefix_lens
,
dim
=
0
)
if
flashinfer_use_ragged
:
model_runner
.
flashinfer_prefill_wrapper_ragged
[
wrapper_id
].
end_forward
()
model_runner
.
flashinfer_prefill_wrapper_ragged
[
wrapper_id
].
begin_forward
(
qo_indptr
,
qo_indptr
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
)
# cached part
model_runner
.
flashinfer_prefill_wrapper_paged
[
wrapper_id
].
end_forward
()
model_runner
.
flashinfer_prefill_wrapper_paged
[
wrapper_id
].
begin_forward
(
qo_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_len
,
num_qo_heads
,
num_kv_heads
,
head_dim
,
1
,
)
python/sglang/srt/model_executor/model_runner.py
View file @
0909bb0d
...
...
@@ -295,7 +295,16 @@ class ModelRunner:
return
c
def
init_flashinfer
(
self
):
self
.
sliding_window_size
=
(
self
.
model
.
get_window_size
()
if
hasattr
(
self
.
model
,
"get_window_size"
)
else
None
)
if
self
.
server_args
.
disable_flashinfer
:
assert
(
self
.
sliding_window_size
is
None
),
"turn on flashinfer to support window attention"
self
.
flashinfer_prefill_wrapper_ragged
=
None
self
.
flashinfer_prefill_wrapper_paged
=
None
self
.
flashinfer_decode_wrapper
=
None
...
...
@@ -309,20 +318,54 @@ class ModelRunner:
else
:
use_tensor_cores
=
False
self
.
flashinfer_workspace_buffers
=
torch
.
empty
(
2
,
global_config
.
flashinfer_workspace_size
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
self
.
flashinfer_prefill_wrapper_ragged
=
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
flashinfer_workspace_buffers
[
0
],
"NHD"
)
self
.
flashinfer_prefill_wrapper_paged
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffers
[
1
],
"NHD"
)
self
.
flashinfer_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffers
[
0
],
"NHD"
,
use_tensor_cores
=
use_tensor_cores
,
)
if
self
.
sliding_window_size
is
None
:
self
.
flashinfer_workspace_buffers
=
torch
.
empty
(
2
,
global_config
.
flashinfer_workspace_size
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
,
)
self
.
flashinfer_prefill_wrapper_ragged
=
(
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
flashinfer_workspace_buffers
[
0
],
"NHD"
)
)
self
.
flashinfer_prefill_wrapper_paged
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffers
[
1
],
"NHD"
)
self
.
flashinfer_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffers
[
0
],
"NHD"
,
use_tensor_cores
=
use_tensor_cores
,
)
else
:
workspace_buffers
=
torch
.
empty
(
4
,
global_config
.
flashinfer_workspace_size
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
,
)
self
.
flashinfer_prefill_wrapper_ragged
=
[]
self
.
flashinfer_prefill_wrapper_paged
=
[]
self
.
flashinfer_decode_wrapper
=
[]
for
i
in
range
(
2
):
self
.
flashinfer_prefill_wrapper_ragged
.
append
(
BatchPrefillWithRaggedKVCacheWrapper
(
workspace_buffers
[
2
*
i
+
0
],
"NHD"
)
)
self
.
flashinfer_prefill_wrapper_paged
.
append
(
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffers
[
2
*
i
+
1
],
"NHD"
)
)
self
.
flashinfer_decode_wrapper
.
append
(
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffers
[
2
*
i
+
0
],
"NHD"
,
use_tensor_cores
=
use_tensor_cores
,
)
)
def
init_cuda_graphs
(
self
):
from
sglang.srt.model_executor.cuda_graph_runner
import
CudaGraphRunner
...
...
@@ -358,7 +401,10 @@ class ModelRunner:
return
self
.
cuda_graph_runner
.
replay
(
batch
)
input_metadata
=
InputMetadata
.
from_schedule_batch
(
self
,
batch
,
ForwardMode
.
DECODE
self
,
batch
,
ForwardMode
.
DECODE
,
sliding_window_size
=
self
.
sliding_window_size
,
)
return
self
.
model
.
forward
(
...
...
@@ -368,7 +414,10 @@ class ModelRunner:
@
torch
.
inference_mode
()
def
forward_extend
(
self
,
batch
:
ScheduleBatch
):
input_metadata
=
InputMetadata
.
from_schedule_batch
(
self
,
batch
,
forward_mode
=
ForwardMode
.
EXTEND
self
,
batch
,
forward_mode
=
ForwardMode
.
EXTEND
,
sliding_window_size
=
self
.
sliding_window_size
,
)
return
self
.
model
.
forward
(
batch
.
input_ids
,
input_metadata
.
positions
,
input_metadata
...
...
@@ -377,7 +426,10 @@ class ModelRunner:
@
torch
.
inference_mode
()
def
forward_extend_multi_modal
(
self
,
batch
:
ScheduleBatch
):
input_metadata
=
InputMetadata
.
from_schedule_batch
(
self
,
batch
,
forward_mode
=
ForwardMode
.
EXTEND
self
,
batch
,
forward_mode
=
ForwardMode
.
EXTEND
,
sliding_window_size
=
self
.
sliding_window_size
,
)
return
self
.
model
.
forward
(
batch
.
input_ids
,
...
...
python/sglang/srt/models/gemma2.py
View file @
0909bb0d
...
...
@@ -44,6 +44,12 @@ from sglang.srt.layers.radix_attention import RadixAttention
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def
get_window_size
(
config
):
return
config
.
sliding_window
-
1
class
GemmaRMSNorm
(
CustomOp
):
"""RMS normalization for Gemma.
...
...
@@ -200,17 +206,14 @@ class Gemma2Attention(nn.Module):
dtype
=
torch
.
get_default_dtype
(),
)
# from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
use_sliding_window
=
layer_idx
%
2
==
1
and
config
.
sliding_window
is
not
None
del
use_sliding_window
# Unused.
use_sliding_window
=
layer_idx
%
2
==
0
and
hasattr
(
config
,
"sliding_window"
)
self
.
attn
=
RadixAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
layer_id
=
layer_idx
,
sliding_window_size
=
get_window_size
(
config
)
if
use_sliding_window
else
-
1
,
logit_cap
=
self
.
config
.
attn_logit_softcapping
,
)
...
...
@@ -403,6 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
input_ids
,
hidden_states
,
self
.
model
.
embed_tokens
.
weight
,
input_metadata
)
def
get_window_size
(
self
):
return
get_window_size
(
self
.
config
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
...
...
python/sglang/srt/server_args.py
View file @
0909bb0d
...
...
@@ -17,9 +17,12 @@ limitations under the License.
import
argparse
import
dataclasses
import
logging
import
random
from
typing
import
List
,
Optional
,
Union
logger
=
logging
.
getLogger
(
__name__
)
@
dataclasses
.
dataclass
class
ServerArgs
:
...
...
@@ -446,6 +449,15 @@ class ServerArgs:
assert
not
(
self
.
dp_size
>
1
and
self
.
node_rank
is
not
None
),
"multi-node data parallel is not supported"
if
"gemma-2"
in
self
.
model_path
.
lower
():
logger
.
info
(
f
"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
)
self
.
disable_radix_cache
=
True
self
.
disable_regex_jump_forward
=
True
self
.
disable_flashinfer
=
False
self
.
disable_cuda_graph
=
True
self
.
chunked_prefill_size
=
None
@
dataclasses
.
dataclass
...
...
python/sglang/test/long_prompt
0 → 100644
View file @
0909bb0d
This diff is collapsed.
Click to expand it.
python/sglang/test/runners.py
View file @
0909bb0d
...
...
@@ -15,6 +15,7 @@ limitations under the License.
import
json
import
multiprocessing
import
os
from
dataclasses
import
dataclass
from
typing
import
List
,
Union
...
...
@@ -31,8 +32,14 @@ DEFAULT_PROMPTS = [
"The capital of the United Kindom is"
,
"Today is a sunny day and I like"
,
"AI is a field of computer science focused on"
,
"Apple is red. Banana is Yellow. "
*
800
+
"Apple is"
,
]
dirpath
=
os
.
path
.
dirname
(
__file__
)
with
open
(
os
.
path
.
join
(
dirpath
,
"long_prompt"
),
"r"
)
as
f
:
long_prompt
=
f
.
read
()
DEFAULT_PROMPTS
.
append
(
long_prompt
)
NUM_TOP_LOGPROBS
=
5
...
...
@@ -125,16 +132,14 @@ class HFRunner:
)
logits
=
self
.
model
.
forward
(
input_ids
).
logits
[
0
]
logprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
).
tolist
()
# index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
# print("index", index_of_max)
logprobs
=
[
sorted
(
token_logprobs
,
reverse
=
True
)[:
NUM_TOP_LOGPROBS
]
for
token_logprobs
in
logprobs
]
prefill_logprobs
.
append
(
logprobs
)
logprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
logprobs
,
top_indices
=
torch
.
topk
(
logprobs
,
k
=
NUM_TOP_LOGPROBS
,
dim
=-
1
)
# print("index", top_indices)
prefill_logprobs
.
append
(
logprobs
.
tolist
())
del
logits
del
logprobs
out_queue
.
put
(
ModelOutput
(
...
...
@@ -186,6 +191,7 @@ class SRTRunner:
tp_size
=
tp_size
,
dtype
=
get_dtype_str
(
torch_dtype
),
port
=
port
,
mem_fraction_static
=
0.7
,
)
def
forward
(
...
...
scripts/playground/reference_hf.py
View file @
0909bb0d
...
...
@@ -35,18 +35,17 @@ def normal_text(args):
args
.
model_path
,
torch_dtype
=
torch
.
float16
,
low_cpu_mem_usage
=
True
,
device_map
=
"auto"
,
trust_remote_code
=
True
,
)
m
.
cuda
()
print
(
m
)
prompts
=
[
"The capital of France is"
,
"The capital of the United Kindom is"
,
"Today is a sunny day and I like"
,
]
max_new_tokens
=
32
max_new_tokens
=
16
for
p
in
prompts
:
if
isinstance
(
p
,
str
):
...
...
@@ -58,10 +57,11 @@ def normal_text(args):
input_ids
,
do_sample
=
False
,
max_new_tokens
=
max_new_tokens
)
output_str
=
t
.
decode
(
output_ids
[
0
])
print
(
output_str
)
prefill_logits
=
m
.
forward
(
input_ids
).
logits
[
0
][
-
1
]
print
(
"prefill logits"
,
prefill_logits
)
print
(
output_str
)
@
torch
.
inference_mode
()
...
...
test/srt/models/test_embedding_models.py
View file @
0909bb0d
...
...
@@ -53,11 +53,13 @@ class TestEmbeddingModels(unittest.TestCase):
srt_logits
=
torch
.
Tensor
(
srt_outputs
.
embed_logits
[
i
])
similarities
=
torch
.
tensor
(
get_similarities
(
hf_logits
,
srt_logits
))
print
(
"max similarity diff"
,
torch
.
max
(
abs
(
similarities
-
1
)))
tolerance
=
1e-2
assert
torch
.
all
(
abs
(
similarities
-
1
)
<
tolerance
),
f
"embeddings not all close"
if
hf_logits
.
shape
[
0
]
<=
100
:
tolerance
=
1e-2
assert
torch
.
all
(
abs
(
similarities
-
1
)
<
tolerance
),
f
"embeddings not all close"
def
test_prefill_logits
(
self
):
for
model
,
tp_size
in
MODELS
:
...
...
test/srt/models/test_generation_models.py
View file @
0909bb0d
...
...
@@ -20,8 +20,8 @@ import torch
from
sglang.test.runners
import
DEFAULT_PROMPTS
,
HFRunner
,
SRTRunner
MODELS
=
[
(
"meta-llama/Meta-Llama-3.1-8B-Instruct"
,
1
),
(
"google/gemma-2-2b"
,
1
),
(
"meta-llama/Meta-Llama-3.1-8B-Instruct"
,
1
,
1.
1
),
(
"google/gemma-2-2b"
,
1
,
3
),
]
TORCH_DTYPES
=
[
torch
.
float16
]
...
...
@@ -35,6 +35,7 @@ class TestGenerationModels(unittest.TestCase):
tp_size
,
torch_dtype
,
max_new_tokens
,
long_context_tolerance
,
)
->
None
:
with
HFRunner
(
model_path
,
torch_dtype
=
torch_dtype
,
is_generation_model
=
True
...
...
@@ -53,15 +54,19 @@ class TestGenerationModels(unittest.TestCase):
hf_logprobs
=
torch
.
Tensor
(
hf_outputs
.
top_input_logprobs
[
i
])
srt_logprobs
=
torch
.
Tensor
(
srt_outputs
.
top_input_logprobs
[
i
])
tolerance
=
3e-2
assert
torch
.
all
(
abs
(
hf_logprobs
-
srt_logprobs
)
<
tolerance
),
f
"prefill logprobs not all close"
print
(
"max_diff"
,
torch
.
max
(
abs
(
hf_logprobs
-
srt_logprobs
)))
if
hf_logprobs
.
shape
[
0
]
<=
100
:
tolerance
=
3e-2
assert
torch
.
all
(
abs
(
hf_logprobs
-
srt_logprobs
)
<
tolerance
),
f
"prefill logprobs not all close"
print
(
hf_outputs
.
output_strs
)
print
(
srt_outputs
.
output_strs
)
assert
hf_outputs
.
output_strs
==
srt_outputs
.
output_strs
def
test_prefill_logits
(
self
):
for
model
,
tp_size
in
MODELS
:
def
test_prefill_logits
_and_output_strs
(
self
):
for
model
,
tp_size
,
long_context_tolerance
in
MODELS
:
for
torch_dtype
in
TORCH_DTYPES
:
max_new_tokens
=
8
self
.
assert_close_prefill_logits_and_output_strs
(
...
...
@@ -70,6 +75,7 @@ class TestGenerationModels(unittest.TestCase):
tp_size
,
torch_dtype
,
max_new_tokens
,
long_context_tolerance
=
long_context_tolerance
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment