sglang · Commit dc188132 (unverified)

Fix perf regression on small batch sizes (#3008)

Authored by Lianmin Zheng on Jan 20, 2025; committed by GitHub on Jan 20, 2025.
Parent: 10bfce71
Showing 2 changed files with 11 additions and 7 deletions:
- python/sglang/srt/layers/radix_attention.py (+2, -2)
- python/sglang/srt/mem_cache/memory_pool.py (+9, -5)
python/sglang/srt/layers/radix_attention.py

@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
-        self.k_scale = 1.0
-        self.v_scale = 1.0
+        self.k_scale = None
+        self.v_scale = None

     def forward(
         self,
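For context (my reading of the fix, not stated in the commit): with a default scale of 1.0, every KV-cache write still paid for an elementwise divide even though it was numerically a no-op; switching the default to None lets downstream code skip that pass entirely. A minimal timing sketch of that cost, with illustrative shapes and names, assuming a standard PyTorch install:

import time

import torch

# Hypothetical micro-benchmark (not part of the commit): dividing by 1.0 is a
# numeric no-op, but it still performs a full elementwise pass over the tensor.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(64, 128, 128, device=device, dtype=torch.float16)


def sync():
    if device == "cuda":
        torch.cuda.synchronize()


sync()
t0 = time.perf_counter()
for _ in range(1000):
    _ = x / 1.0  # old default: always divides, even when the scale is 1.0
sync()
t1 = time.perf_counter()

scale = None
sync()
t2 = time.perf_counter()
for _ in range(1000):
    if scale is not None:  # new default: None means "skip the divide entirely"
        _ = x / scale
sync()
t3 = time.perf_counter()

print(f"divide by 1.0: {t1 - t0:.4f}s   skip (scale=None): {t3 - t2:.4f}s")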
python/sglang/srt/mem_cache/memory_pool.py

@@ -27,7 +27,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import psutil

@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = (cache_k / k_scale).to(self.dtype)
-            cache_v = (cache_v / v_scale).to(self.dtype)
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
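Taken together, the rewritten cache-write path divides in place only when a scale is actually supplied, and only then casts to the pool dtype. A simplified, standalone sketch of that control flow; the function name, parameters, and buffer layout are illustrative and not the sglang API:

from typing import Optional

import torch


def write_kv_cache(
    k_buffer: torch.Tensor,   # per-layer key buffer (dtype == store_dtype)
    v_buffer: torch.Tensor,   # per-layer value buffer (dtype == store_dtype)
    loc: torch.Tensor,        # token slots to write into
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    dtype: torch.dtype,       # compute dtype of the pool
    store_dtype: torch.dtype, # storage dtype (reinterpreted view, same element size)
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
) -> None:
    # Post-commit behaviour: the elementwise divide only runs when a scale is
    # actually provided; k_scale=None / v_scale=None is a true no-op.
    if cache_k.dtype != dtype:
        if k_scale is not None:
            cache_k.div_(k_scale)  # in-place, mirrors the commit
        if v_scale is not None:
            cache_v.div_(v_scale)
        cache_k = cache_k.to(dtype)
        cache_v = cache_v.to(dtype)
    if store_dtype != dtype:
        # Reinterpret the bits for storage without copying.
        cache_k = cache_k.view(store_dtype)
        cache_v = cache_v.view(store_dtype)
    k_buffer[loc] = cache_k
    v_buffer[loc] = cache_v

The in-place div_ mirrors the commit and avoids allocating an intermediate tensor; the key point is that the common unscaled case now writes the cache without any extra elementwise pass, which is where the small-batch regression showed up.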