sglang / Commits / 0909bb0d

Unverified commit 0909bb0d, authored Aug 13, 2024 by Ying Sheng, committed by GitHub on Aug 13, 2024
[Feat] Add window attention for gemma-2 (#1056)
parent ad3e4f16

Showing 11 changed files with 319 additions and 126 deletions (+319 -126)
python/sglang/bench_latency.py                            +1   -1
python/sglang/srt/layers/radix_attention.py              +40  -19
python/sglang/srt/model_executor/forward_batch_info.py  +145  -58
python/sglang/srt/model_executor/model_runner.py         +69  -17
python/sglang/srt/models/gemma2.py                       +11   -5
python/sglang/srt/server_args.py                         +12   -0
python/sglang/test/long_prompt                            +1   -0
python/sglang/test/runners.py                            +16  -10
scripts/playground/reference_hf.py                        +4   -4
test/srt/models/test_embedding_models.py                  +6   -4
test/srt/models/test_generation_models.py                +14   -8
python/sglang/bench_latency.py

@@ -64,7 +64,7 @@ class BenchArgs:
     run_name: str = "before"
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
-    output_len: Tuple[int] = (4,)
+    output_len: Tuple[int] = (16,)
     result_filename: str = ""
     correctness_test: bool = False
     # This is only used for correctness test
python/sglang/srt/layers/radix_attention.py

@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        sliding_window_size: int = -1,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):

@@ -46,6 +47,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.sliding_window_size = sliding_window_size

         if (
             not global_server_args_dict.get("disable_flashinfer", False)

@@ -113,40 +115,52 @@ class RadixAttention(nn.Module):
         return o

     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
+        prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
+        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
+        if self.sliding_window_size != -1:
+            prefill_wrapper_ragged = prefill_wrapper_ragged[0]
+            prefill_wrapper_paged = prefill_wrapper_paged[0]
+        else:
+            if isinstance(prefill_wrapper_ragged, list):
+                prefill_wrapper_ragged = prefill_wrapper_ragged[1]
+            if isinstance(prefill_wrapper_paged, list):
+                prefill_wrapper_paged = prefill_wrapper_paged[1]
+
         if not input_metadata.flashinfer_use_ragged:
             self.store_kv_cache(k, v, input_metadata)

-            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
+            o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=True,
                 sm_scale=self.scaling,
+                window_left=self.sliding_window_size,
                 logits_soft_cap=self.logit_cap,
             )
         else:
-            o1, s1 = (
-                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                    causal=True,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
+            o1, s1 = prefill_wrapper_ragged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                causal=True,
+                sm_scale=self.scaling,
+                window_left=self.sliding_window_size,
+                logits_soft_cap=self.logit_cap,
             )

             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                o2, s2 = (
-                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
-                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                        causal=False,
-                        sm_scale=self.scaling,
-                        logits_soft_cap=self.logit_cap,
-                    )
+                # TODO window attention + radix attention will come up in next PR
+                assert self.sliding_window_size == -1
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                    causal=False,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
                 )

                 o, _ = merge_state(o1, s1, o2, s2)

@@ -158,9 +172,16 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        decode_wrapper = input_metadata.flashinfer_decode_wrapper
+        if self.sliding_window_size != -1:
+            decode_wrapper = decode_wrapper[0]
+        else:
+            if isinstance(decode_wrapper, list):
+                decode_wrapper = decode_wrapper[1]
+
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.flashinfer_decode_wrapper.forward(
+        o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
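Note (not part of the commit): the `window_left` argument passed to the FlashInfer wrappers above bounds how far back each query may attend, and -1 means no window, matching the `sliding_window_size = -1` default. A self-contained sketch of the intended mask semantics, written only for illustration:

import torch

def sliding_window_causal_mask(seq_len: int, window_left: int) -> torch.Tensor:
    # Entry (i, j) is True if query i may attend to key j.
    # window_left < 0 disables the window (plain causal attention);
    # otherwise each query sees itself plus at most `window_left` previous tokens.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    causal = j <= i
    if window_left < 0:
        return causal
    return causal & (j >= i - window_left)

# With window_left=2, query 5 attends to keys 3, 4, 5 only.
print(sliding_window_causal_mask(6, 2).int())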
python/sglang/srt/model_executor/forward_batch_info.py

@@ -16,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import torch

@@ -154,6 +154,7 @@ class InputMetadata:
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
         forward_mode: ForwardMode,
+        sliding_window_size: Optional[int] = None,
     ):
         ret = cls(
             forward_mode=forward_mode,

@@ -197,7 +198,7 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True
         ret.init_flashinfer_handlers(
-            model_runner, prefix_lens, flashinfer_use_ragged
+            model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
         )

         return ret

@@ -216,7 +217,11 @@ class InputMetadata:
         self.triton_max_extend_len = int(torch.max(extend_seq_lens))

     def init_flashinfer_handlers(
-        self, model_runner, prefix_lens, flashinfer_use_ragged
+        self,
+        model_runner,
+        prefix_lens,
+        flashinfer_use_ragged,
+        sliding_window_size=None,
     ):
         update_flashinfer_indices(
             self.forward_mode,

@@ -225,6 +230,7 @@ class InputMetadata:
             self.seq_lens,
             prefix_lens,
             flashinfer_use_ragged=flashinfer_use_ragged,
+            sliding_window_size=sliding_window_size,
         )

         (

@@ -248,6 +254,7 @@ def update_flashinfer_indices(
     prefix_lens,
     flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
+    sliding_window_size=None,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size

@@ -255,6 +262,7 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)

+    if sliding_window_size is None:
         if flashinfer_use_ragged:
             paged_kernel_lens = prefix_lens
         else:

@@ -317,3 +325,82 @@ def update_flashinfer_indices(
             head_dim,
             1,
         )
+    else:
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        for wrapper_id in range(2):
+            if flashinfer_use_ragged:
+                paged_kernel_lens = prefix_lens
+            else:
+                paged_kernel_lens = seq_lens
+
+            if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
+                paged_kernel_lens = torch.minimum(
+                    paged_kernel_lens, torch.tensor(sliding_window_size)
+                )
+                kv_start_idx = seq_lens - paged_kernel_lens
+            else:
+                kv_start_idx = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+
+            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i],
+                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
+            if forward_mode == ForwardMode.DECODE:
+                # CUDA graph uses different flashinfer_decode_wrapper
+                if flashinfer_decode_wrapper is None:
+                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+                flashinfer_decode_wrapper[wrapper_id].end_forward()
+                flashinfer_decode_wrapper[wrapper_id].begin_forward(
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
+            else:
+                # extend part
+                qo_indptr = torch.zeros(
+                    (batch_size + 1,), dtype=torch.int32, device="cuda"
+                )
+                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+                if flashinfer_use_ragged:
+                    model_runner.flashinfer_prefill_wrapper_ragged[wrapper_id].end_forward()
+                    model_runner.flashinfer_prefill_wrapper_ragged[wrapper_id].begin_forward(
+                        qo_indptr,
+                        qo_indptr,
+                        num_qo_heads,
+                        num_kv_heads,
+                        head_dim,
+                    )
+
+                # cached part
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
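Note (not part of the commit): for decode, wrapper 0 only needs the last `sliding_window_size` cached tokens of each request, which is what the clamp on `paged_kernel_lens` and the `kv_start_idx` offset above achieve. A standalone sketch of that index arithmetic with made-up lengths:

import torch

sliding_window_size = 4
seq_lens = torch.tensor([3, 6, 10])          # cached tokens per request (made up)

# Sliding-window layers (wrapper 0): keep at most the last `sliding_window_size` tokens.
paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
kv_start_idx = seq_lens - paged_kernel_lens  # first cached token the new query still sees

print(paged_kernel_lens.tolist())  # [3, 4, 4]
print(kv_start_idx.tolist())       # [0, 2, 6]

# Full-attention layers (wrapper 1) keep every cached token, i.e. kv_start_idx stays 0.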
python/sglang/srt/model_executor/model_runner.py

@@ -295,7 +295,16 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
+
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None

@@ -309,12 +318,18 @@ class ModelRunner:
         else:
             use_tensor_cores = False

+        if self.sliding_window_size is None:
             self.flashinfer_workspace_buffers = torch.empty(
-                2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
+                2,
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
             )
-            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                self.flashinfer_workspace_buffers[0], "NHD"
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffers[0], "NHD"
+                )
             )
             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                 self.flashinfer_workspace_buffers[1], "NHD"
             )

@@ -323,6 +338,34 @@ class ModelRunner:
                 "NHD",
                 use_tensor_cores=use_tensor_cores,
             )
+        else:
+            workspace_buffers = torch.empty(
+                4,
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = []
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_ragged.append(
+                    BatchPrefillWithRaggedKVCacheWrapper(
+                        workspace_buffers[2 * i + 0], "NHD"
+                    )
+                )
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        workspace_buffers[2 * i + 1], "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        workspace_buffers[2 * i + 0],
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )

     def init_cuda_graphs(self):
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

@@ -358,7 +401,10 @@ class ModelRunner:
             return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, ForwardMode.DECODE
+            self,
+            batch,
+            ForwardMode.DECODE,
+            sliding_window_size=self.sliding_window_size,
         )

         return self.model.forward(

@@ -368,7 +414,10 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
+            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata

@@ -377,7 +426,10 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
+            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids,
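Note (not part of the commit): when a window size is present, every FlashInfer wrapper becomes a pair over four workspace buffers, and RadixAttention picks index 0 for sliding-window layers and index 1 for full-attention layers. A dependency-free sketch of that selection rule, mirroring the logic in radix_attention.py:

from typing import List, Union

def pick_wrapper(wrapper: Union[object, List[object]], sliding_window_size: int):
    # Index 0 serves sliding-window layers, index 1 serves full-attention layers.
    # When no window is configured the wrapper may still be a single object.
    if sliding_window_size != -1:
        return wrapper[0]
    if isinstance(wrapper, list):
        return wrapper[1]
    return wrapper

wrappers = ["window_wrapper", "global_wrapper"]  # placeholders for the real FlashInfer objects
print(pick_wrapper(wrappers, sliding_window_size=4095))  # window_wrapper
print(pick_wrapper(wrappers, sliding_window_size=-1))    # global_wrapper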
python/sglang/srt/models/gemma2.py

@@ -44,6 +44,12 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_window_size(config):
+    return config.sliding_window - 1
+
+
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.

@@ -200,17 +206,14 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )

-        # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
-        del use_sliding_window  # Unused.
+        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
+            sliding_window_size=get_window_size(config) if use_sliding_window else -1,
             logit_cap=self.config.attn_logit_softcapping,
         )

@@ -403,6 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )

+    def get_window_size(self):
+        return get_window_size(self.config)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
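Note (not part of the commit): HF's `sliding_window` counts the current token while the `window_left` convention used here counts only previous tokens, hence the `- 1`; and only Gemma-2's even-indexed layers get a window. A quick illustrative check (the 4096 value is Gemma-2's published config value, assumed here):

class FakeConfig:
    sliding_window = 4096  # Gemma-2 config value, assumed for illustration

def get_window_size(config):
    # HF's window is inclusive of the current token; window_left excludes it.
    return config.sliding_window - 1

config = FakeConfig()
for layer_idx in range(4):
    use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
    print(layer_idx, get_window_size(config) if use_sliding_window else -1)
# layers 0 and 2 -> 4095, layers 1 and 3 -> -1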
python/sglang/srt/server_args.py

@@ -17,9 +17,12 @@ limitations under the License.
 import argparse
 import dataclasses
+import logging
 import random
 from typing import List, Optional, Union

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class ServerArgs:

@@ -446,6 +449,15 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"

+        if "gemma-2" in self.model_path.lower():
+            logger.info(
+                f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
+            )
+            self.disable_radix_cache = True
+            self.disable_regex_jump_forward = True
+            self.disable_flashinfer = False
+            self.disable_cuda_graph = True
+            self.chunked_prefill_size = None

 @dataclasses.dataclass
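Note (not part of the commit): with these overrides, serving a gemma-2 checkpoint should not require any extra flags. A minimal usage sketch based on sglang's documented frontend API at the time; the model path, prompt, and token budget are placeholders, not taken from the diff:

import sglang as sgl

runtime = sgl.Runtime(model_path="google/gemma-2-2b")  # a gemma-2 path triggers the overrides above
sgl.set_default_backend(runtime)

@sgl.function
def capital(s):
    s += "The capital of France is" + sgl.gen("answer", max_tokens=8)

print(capital.run()["answer"])
runtime.shutdown()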
python/sglang/test/long_prompt

new file (0 → 100644); diff collapsed in this view
python/sglang/test/runners.py

@@ -15,6 +15,7 @@ limitations under the License.
 import json
 import multiprocessing
+import os
 from dataclasses import dataclass
 from typing import List, Union

@@ -31,8 +32,14 @@ DEFAULT_PROMPTS = [
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
 ]

+dirpath = os.path.dirname(__file__)
+with open(os.path.join(dirpath, "long_prompt"), "r") as f:
+    long_prompt = f.read()
+DEFAULT_PROMPTS.append(long_prompt)
+
 NUM_TOP_LOGPROBS = 5

@@ -125,16 +132,14 @@ class HFRunner:
                     )

                     logits = self.model.forward(input_ids).logits[0]
                     logprobs = F.log_softmax(
                         logits, dim=-1, dtype=torch.float32
-                    ).tolist()
-                    # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
-                    # print("index", index_of_max)
-                    logprobs = [
-                        sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
-                        for token_logprobs in logprobs
-                    ]
-                    prefill_logprobs.append(logprobs)
+                    )
+                    logprobs, top_indices = torch.topk(
+                        logprobs, k=NUM_TOP_LOGPROBS, dim=-1
+                    )
+                    # print("index", top_indices)
+                    prefill_logprobs.append(logprobs.tolist())
+                    del logits
+                    del logprobs

                     out_queue.put(
                         ModelOutput(

@@ -186,6 +191,7 @@ class SRTRunner:
                 tp_size=tp_size,
                 dtype=get_dtype_str(torch_dtype),
                 port=port,
+                mem_fraction_static=0.7,
             )

     def forward(
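Note (not part of the commit): the two new prompts exist to exercise sequences longer than Gemma-2's 4096-token window; the repeated fruit sentence alone is several thousand tokens, and long_prompt is a full file checked in next to the tests. A rough back-of-the-envelope check (token counts depend on the tokenizer, so this is only an estimate):

prompt = "Apple is red. Banana is Yellow. " * 800 + "Apple is"
words = len(prompt.split())
approx_tokens = int(words * 1.3)  # rough rule of thumb, not an exact tokenizer count
print(words, approx_tokens)       # 4802 words, ~6242 tokens, comfortably past a 4096 window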
scripts/playground/reference_hf.py

@@ -35,18 +35,17 @@ def normal_text(args):
         args.model_path,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True,
+        device_map="auto",
         trust_remote_code=True,
     )
     m.cuda()
+    print(m)

     prompts = [
         "The capital of France is",
         "The capital of the United Kindom is",
         "Today is a sunny day and I like",
     ]
-    max_new_tokens = 32
+    max_new_tokens = 16

     for p in prompts:
         if isinstance(p, str):

@@ -58,10 +57,11 @@ def normal_text(args):
             input_ids, do_sample=False, max_new_tokens=max_new_tokens
         )
         output_str = t.decode(output_ids[0])
-        print(output_str)

         prefill_logits = m.forward(input_ids).logits[0][-1]
         print("prefill logits", prefill_logits)
+        print(output_str)


 @torch.inference_mode()
test/srt/models/test_embedding_models.py

@@ -53,7 +53,9 @@ class TestEmbeddingModels(unittest.TestCase):
                 srt_logits = torch.Tensor(srt_outputs.embed_logits[i])

                 similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
+                print("max similarity diff", torch.max(abs(similarities - 1)))

+                if hf_logits.shape[0] <= 100:
                     tolerance = 1e-2
                     assert torch.all(
                         abs(similarities - 1) < tolerance
test/srt/models/test_generation_models.py

@@ -20,8 +20,8 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

 MODELS = [
-    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
-    ("google/gemma-2-2b", 1),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1),
+    ("google/gemma-2-2b", 1, 3),
 ]
 TORCH_DTYPES = [torch.float16]

@@ -35,6 +35,7 @@ class TestGenerationModels(unittest.TestCase):
         tp_size,
         torch_dtype,
         max_new_tokens,
+        long_context_tolerance,
     ) -> None:
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation_model=True

@@ -53,15 +54,19 @@ class TestGenerationModels(unittest.TestCase):
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
             srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
+            print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))

+            if hf_logprobs.shape[0] <= 100:
                 tolerance = 3e-2
                 assert torch.all(
                     abs(hf_logprobs - srt_logprobs) < tolerance
                 ), f"prefill logprobs not all close"

+        print(hf_outputs.output_strs)
+        print(srt_outputs.output_strs)
         assert hf_outputs.output_strs == srt_outputs.output_strs

-    def test_prefill_logits(self):
-        for model, tp_size in MODELS:
+    def test_prefill_logits_and_output_strs(self):
+        for model, tp_size, long_context_tolerance in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 max_new_tokens = 8
                 self.assert_close_prefill_logits_and_output_strs(

@@ -70,6 +75,7 @@ class TestGenerationModels(unittest.TestCase):
                     tp_size,
                     torch_dtype,
                     max_new_tokens,
+                    long_context_tolerance=long_context_tolerance,
                 )