sglang · Commit dc188132 (unverified)

Fix perf regression on small batch sizes (#3008)

Authored by Lianmin Zheng on Jan 20, 2025; committed by GitHub on Jan 20, 2025.
Parent: 10bfce71
Showing 2 changed files with 11 additions and 7 deletions.
python/sglang/srt/layers/radix_attention.py  +2 -2
python/sglang/srt/mem_cache/memory_pool.py   +9 -5
python/sglang/srt/layers/radix_attention.py

```diff
@@ -47,8 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
-        self.k_scale = 1.0
-        self.v_scale = 1.0
+        self.k_scale = None
+        self.v_scale = None

     def forward(
         self,
```
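The constructor change makes "no KV scaling configured" representable: with a plain float default of 1.0, downstream code cannot cheaply distinguish an unscaled cache from one whose scale happens to be 1.0, so it scaled unconditionally. A minimal sketch of the intent, using hypothetical names (`Attn`, `maybe_load_kv_scales`, `quant_cfg`) rather than sglang's actual loader API:

```python
# Minimal sketch, not sglang's loader code: only a checkpoint that actually
# ships KV scales populates the attributes; the common BF16/FP16 path leaves
# them as None so later code can skip dequant scaling entirely.
from typing import Optional


class Attn:
    def __init__(self) -> None:
        self.k_scale: Optional[float] = None  # was 1.0 before this commit
        self.v_scale: Optional[float] = None  # was 1.0 before this commit


def maybe_load_kv_scales(attn: Attn, quant_cfg: Optional[dict]) -> None:
    # Populate scales only when the checkpoint provides them.
    if quant_cfg and "k_scale" in quant_cfg:
        attn.k_scale = float(quant_cfg["k_scale"])
    if quant_cfg and "v_scale" in quant_cfg:
        attn.v_scale = float(quant_cfg["v_scale"])
```

Checkpoints that do provide KV scales would still set these fields; everything else leaves them as None so the cache-write path in memory_pool.py below can skip the scaling work.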
python/sglang/srt/mem_cache/memory_pool.py

```diff
@@ -27,7 +27,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import psutil
@@ -270,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = (cache_k / k_scale).to(self.dtype)
-            cache_v = (cache_v / v_scale).to(self.dtype)
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
```
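In the branch above (taken when the incoming K/V dtype differs from the pool dtype, e.g. an FP8 KV cache), the old code always materialized `(cache_k / k_scale)` and `(cache_v / v_scale)` even when the scales were just the 1.0 defaults; the new code divides in place and only when a scale is actually set. A standalone sketch of that branch, assuming flat per-layer buffers; `write_kv`, `k_buffer`, and `v_buffer` are placeholder names, not the sglang API (the real method is `MHATokenToKVPool.set_kv_buffer`):

```python
# Standalone sketch of the new scaling branch under the assumptions above.
from typing import Optional

import torch


def write_kv(
    k_buffer: torch.Tensor,
    v_buffer: torch.Tensor,
    loc: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    store_dtype: torch.dtype,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
) -> None:
    if cache_k.dtype != store_dtype:
        # Divide only when a dequant scale is actually configured; the old
        # code divided by the default 1.0 and allocated a temporary tensor
        # on every write along this path.
        if k_scale is not None:
            cache_k.div_(k_scale)
        if v_scale is not None:
            cache_v.div_(v_scale)
        cache_k = cache_k.to(store_dtype)
        cache_v = cache_v.to(store_dtype)
    k_buffer[loc] = cache_k
    v_buffer[loc] = cache_v
```

The in-place `div_` also avoids the temporary that `(cache_k / k_scale)` allocated; presumably the overhead was most visible at small batch sizes, where each decode step writes only a few tokens and fixed per-step costs dominate.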