Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4d253057
"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "532f998b0f894268b69b7310bf06349e26b8543c"
Unverified
Commit
4d253057
authored
Mar 23, 2025
by
Zhiqiang Xie
Committed by
GitHub
Mar 23, 2025
Browse files
Move mem_state update into debug mode (#4525)
parent
11577ced
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
62 deletions
+70
-62
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+70
-62
No files found.
python/sglang/srt/mem_cache/memory_pool.py
View file @
4d253057
...
...
@@ -580,13 +580,20 @@ class MemoryStateInt(IntEnum):
BACKUP
=
4
def
synchronized
(
func
):
@
wraps
(
func
)
def
wrapper
(
self
,
*
args
,
**
kwargs
):
with
self
.
lock
:
return
func
(
self
,
*
args
,
**
kwargs
)
def
synchronized
(
debug_only
=
False
):
def
_decorator
(
func
):
@
wraps
(
func
)
def
wrapper
(
self
,
*
args
,
**
kwargs
):
if
(
not
debug_only
)
or
self
.
debug
:
return
func
(
self
,
*
args
,
**
kwargs
)
with
self
.
lock
:
return
func
(
self
,
*
args
,
**
kwargs
)
else
:
return
True
return
wrapper
return
wrapper
return
_decorator
class
HostKVCache
(
abc
.
ABC
):
...
...
@@ -631,13 +638,9 @@ class HostKVCache(abc.ABC):
self
.
kv_buffer
=
self
.
init_kv_buffer
()
# Initialize memory states and tracking structures.
self
.
mem_state
=
torch
.
zeros
(
(
self
.
size
,),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
# A lock for synchronized operations on memory allocation and state transitions.
self
.
lock
=
threading
.
RLock
()
self
.
debug
=
logger
.
isEnabledFor
(
logging
.
DEBUG
)
self
.
clear
()
@
abc
.
abstractmethod
...
...
@@ -664,97 +667,102 @@ class HostKVCache(abc.ABC):
def
assign_flat_data
(
self
,
indices
,
flat_data
):
raise
NotImplementedError
()
@
synchronized
@
synchronized
()
def
clear
(
self
):
self
.
mem_state
.
fill_
(
0
)
self
.
can_use_mem_size
=
self
.
size
# Initialize memory states and tracking structures.
self
.
mem_state
=
torch
.
zeros
(
(
self
.
size
,),
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
free_slots
=
torch
.
arange
(
self
.
size
,
dtype
=
torch
.
int64
)
@
synchronized
def
get_state
(
self
,
indices
:
torch
.
Tensor
)
->
MemoryStateInt
:
assert
len
(
indices
)
>
0
,
"The indices should not be empty"
states
=
self
.
mem_state
[
indices
]
assert
(
states
==
states
[
0
]
).
all
(),
"The memory slots should have the same state {}"
.
format
(
states
)
return
MemoryStateInt
(
states
[
0
].
item
())
def
available_size
(
self
):
return
len
(
self
.
free_slots
)
@
synchronized
@
synchronized
()
def
alloc
(
self
,
need_size
:
int
)
->
torch
.
Tensor
:
if
need_size
>
self
.
can_use_mem
_size
:
if
need_size
>
self
.
available
_size
()
:
return
None
# todo: de-fragementation
select_index
=
self
.
free_slots
[:
need_size
]
self
.
free_slots
=
self
.
free_slots
[
need_size
:]
self
.
mem_state
[
select_index
]
=
MemoryStateInt
.
RESERVED
self
.
can_use_mem_size
-=
need_size
if
self
.
debug
:
self
.
mem_state
[
select_index
]
=
MemoryStateInt
.
RESERVED
return
select_index
@
synchronized
@
synchronized
()
def
free
(
self
,
indices
:
torch
.
Tensor
)
->
int
:
self
.
free_slots
=
torch
.
cat
([
self
.
free_slots
,
indices
])
if
self
.
debug
:
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
IDLE
return
len
(
indices
)
@
synchronized
(
debug_only
=
True
)
def
get_state
(
self
,
indices
:
torch
.
Tensor
)
->
MemoryStateInt
:
assert
len
(
indices
)
>
0
,
"The indices should not be empty"
states
=
self
.
mem_state
[
indices
]
assert
(
states
==
states
[
0
]
).
all
(),
"The memory slots should have the same state {}"
.
format
(
states
)
return
MemoryStateInt
(
states
[
0
].
item
())
@
synchronized
(
debug_only
=
True
)
def
is_reserved
(
self
,
indices
:
torch
.
Tensor
)
->
bool
:
return
self
.
get_state
(
indices
)
==
MemoryStateInt
.
RESERVED
@
synchronized
@
synchronized
(
debug_only
=
True
)
def
is_protected
(
self
,
indices
:
torch
.
Tensor
)
->
bool
:
return
self
.
get_state
(
indices
)
==
MemoryStateInt
.
PROTECTED
@
synchronized
@
synchronized
(
debug_only
=
True
)
def
is_synced
(
self
,
indices
:
torch
.
Tensor
)
->
bool
:
return
self
.
get_state
(
indices
)
==
MemoryStateInt
.
SYNCED
@
synchronized
@
synchronized
(
debug_only
=
True
)
def
is_backup
(
self
,
indices
:
torch
.
Tensor
)
->
bool
:
return
self
.
get_state
(
indices
)
==
MemoryStateInt
.
BACKUP
@
synchronized
@
synchronized
(
debug_only
=
True
)
def
update_backup
(
self
,
indices
:
torch
.
Tensor
):
assert
self
.
is_synced
(
indices
),
(
f
"The host memory slots should be in SYNCED state before turning into BACKUP. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
if
not
self
.
is_synced
(
indices
):
raise
ValueError
(
f
"The host memory slots should be in SYNCED state before turning into BACKUP. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
BACKUP
@
synchronized
@
synchronized
(
debug_only
=
True
)
def
update_synced
(
self
,
indices
:
torch
.
Tensor
):
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
SYNCED
@
synchronized
@
synchronized
(
debug_only
=
True
)
def
protect_write
(
self
,
indices
:
torch
.
Tensor
):
assert
self
.
is_reserved
(
indices
),
(
f
"The host memory slots should be RESERVED before write operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
if
not
self
.
is_reserved
(
indices
):
raise
ValueError
(
f
"The host memory slots should be RESERVED before write operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
PROTECTED
@
synchronized
@
synchronized
(
debug_only
=
True
)
def
protect_load
(
self
,
indices
:
torch
.
Tensor
):
assert
self
.
is_backup
(
indices
),
(
f
"The host memory slots should be in BACKUP state before load operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
if
not
self
.
is_backup
(
indices
):
raise
ValueError
(
f
"The host memory slots should be in BACKUP state before load operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
PROTECTED
@
synchronized
@
synchronized
(
debug_only
=
True
)
def
complete_io
(
self
,
indices
:
torch
.
Tensor
):
assert
self
.
is_protected
(
indices
),
(
f
"The host memory slots should be PROTECTED during I/O operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
if
not
self
.
is_protected
(
indices
):
raise
ValueError
(
f
"The host memory slots should be PROTECTED during I/O operations. "
f
"Current state:
{
self
.
get_state
(
indices
)
}
"
)
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
SYNCED
def
available_size
(
self
):
return
len
(
self
.
free_slots
)
@
synchronized
def
free
(
self
,
indices
:
torch
.
Tensor
)
->
int
:
self
.
mem_state
[
indices
]
=
MemoryStateInt
.
IDLE
self
.
free_slots
=
torch
.
cat
([
self
.
free_slots
,
indices
])
self
.
can_use_mem_size
+=
len
(
indices
)
return
len
(
indices
)
class
MHATokenToKVPoolHost
(
HostKVCache
):
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment