sglang · Commits · 5ac8b806

Simplify mem state (#623)

Unverified commit 5ac8b806, authored Jul 15, 2024 by Mingyi, committed by GitHub on Jul 15, 2024.
Parent: bae9541e

Showing 7 changed files with 61 additions and 66 deletions (+61 −66):
    benchmark/latency_throughput/bench_serving.py                  +2   -1
    python/sglang/global_config.py                                  +2   -1
    python/sglang/srt/managers/controller/infer_batch.py            +4   -8
    python/sglang/srt/managers/controller/radix_cache.py            +2   -2
    python/sglang/srt/managers/controller/schedule_heuristic.py     +4   -0
    python/sglang/srt/managers/controller/tp_worker.py              +26  -30
    python/sglang/srt/memory_pool.py                                +21  -24
benchmark/latency_throughput/bench_serving.py   (view file @ 5ac8b806)

@@ -297,7 +297,8 @@ def main(args: argparse.Namespace):
     benchmark_time = benchmark_end_time - benchmark_start_time
 
     # Compute the statistics.
-    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
+    latencies = [latency for _, _, latency in REQUEST_LATENCY]
+    avg_latency = np.mean(latencies)
     avg_per_token_latency = np.mean(
         [
             latency / (prompt_len + output_len)
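
Note: this hunk only factors the latency list out of the np.mean call. A minimal sketch of why that helps, assuming the (prompt_len, output_len, latency) tuple layout used in REQUEST_LATENCY above; the sample values and print statements below are illustrative, not part of the benchmark script:

    import numpy as np

    # Hypothetical sample data in the same (prompt_len, output_len, latency) shape.
    REQUEST_LATENCY = [(32, 128, 1.9), (64, 256, 3.4), (16, 64, 1.1)]

    # Materialize the latencies once, then reuse the list for several statistics
    # instead of re-scanning REQUEST_LATENCY for each one.
    latencies = [latency for _, _, latency in REQUEST_LATENCY]
    print("avg:", np.mean(latencies))
    print("p50:", np.percentile(latencies, 50))
    print("p99:", np.percentile(latencies, 99))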
python/sglang/global_config.py   (view file @ 5ac8b806)

@@ -25,7 +25,8 @@ class GlobalConfig:
         # This can improve the speed for large batch sizes during prefill.
         self.layer_sync_threshold = 8192
 
-        # Runtime constants: Flashinfer
+        # Runtime constants: others
+        self.num_continue_decode_steps = 10
         self.flashinfer_workspace_size = 192 * 1024 * 1024
 
         # Output tokenization configs
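
Note: the new `num_continue_decode_steps` constant replaces a hard-coded loop count in the TP worker (see the tp_worker.py hunk further down). A minimal sketch of how such a constant might be consumed; `run_decode_loop` is a hypothetical helper, not a function from the diff, though the `global_config` singleton import path matches how the worker uses it:

    from sglang.global_config import global_config

    def run_decode_loop(server):
        # Run several decode iterations back-to-back to amortize scheduling overhead,
        # with the count coming from config instead of a magic number.
        for _ in range(global_config.num_continue_decode_steps):
            server.forward_decode_batch(server.running_batch)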
python/sglang/srt/managers/controller/infer_batch.py   (view file @ 5ac8b806)

@@ -174,9 +174,6 @@ class Req:
         return False, ""
 
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
     def check_finished(self):
         if self.finished():
             return
@@ -352,7 +349,7 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
             out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
             if out_cache_loc is None:
@@ -422,7 +419,7 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
-        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
 
         if self.token_to_kv_pool.available_size() >= bs:
             return True
@@ -453,7 +450,7 @@ class Batch:
            token_indices = self.req_to_token_pool.req_to_token[
                req_pool_indices_cpu[idx]
            ][last_uncached_pos : seq_lens_cpu[idx]]
-           self.token_to_kv_pool.dec_refs(token_indices)
+           self.token_to_kv_pool.free(token_indices)
 
            # release the last node
            self.tree_cache.dec_lock_ref(req.last_node)
@@ -596,8 +593,7 @@ class Batch:
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
-            # logit_bias can be None
-            if self_val is not None:
+            if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
     def merge(self, other: "Batch"):
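
Note: with reference counting gone from the KV pool, eviction now passes the pool's `free` method (instead of `dec_refs`) as the callback that releases slots. A sketch of that evict-with-callback calling contract; `TinyCache` below is an illustration only, not the real RadixCache:

    import torch
    from typing import Callable

    class TinyCache:
        def __init__(self):
            # Maps a cached prefix id to the KV-pool indices it occupies.
            self.entries: dict[int, torch.Tensor] = {}

        def evict(self, num_tokens: int, free_fn: Callable[[torch.Tensor], None]):
            # The cache picks which entries to drop; the caller supplies the
            # function that actually releases their KV slots, e.g. token_to_kv_pool.free.
            freed = 0
            while freed < num_tokens and self.entries:
                _, indices = self.entries.popitem()
                free_fn(indices)
                freed += len(indices)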
python/sglang/srt/managers/controller/radix_cache.py   (view file @ 5ac8b806)

@@ -82,12 +82,12 @@ class RadixCache:
         if self.disable:
             if del_in_memory_pool:
-                self.token_to_kv_pool.dec_refs(indices)
+                self.token_to_kv_pool.free(indices)
             else:
                 return torch.tensor([], dtype=torch.int64), self.root_node
 
         # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
 
         if del_in_memory_pool:
             self.req_to_token_pool.free(req_pool_idx)
python/sglang/srt/managers/controller/schedule_heuristic.py   (view file @ 5ac8b806)

@@ -13,6 +13,10 @@ class ScheduleHeuristic:
         max_total_num_tokens,
         tree_cache,
     ):
+        if tree_cache.disable and schedule_heuristic == "lpm":
+            # LPM is meaningless when the tree cache is disabled.
+            schedule_heuristic = "fcfs"
+
         self.schedule_heuristic = schedule_heuristic
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
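
Note: this guard silently downgrades the "lpm" (longest prefix match) policy to "fcfs" when the radix cache is disabled, since prefix matching has nothing to match against. A standalone sketch of the same fallback logic; the function name and the SimpleNamespace stand-in for the tree cache are illustrative, not from the file:

    from types import SimpleNamespace

    def resolve_schedule_heuristic(schedule_heuristic: str, tree_cache) -> str:
        # Longest-prefix-match scheduling needs the radix cache; without it,
        # fall back to first-come-first-served.
        if tree_cache.disable and schedule_heuristic == "lpm":
            return "fcfs"
        return schedule_heuristic

    assert resolve_schedule_heuristic("lpm", SimpleNamespace(disable=True)) == "fcfs"
    assert resolve_schedule_heuristic("lpm", SimpleNamespace(disable=False)) == "lpm"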
python/sglang/srt/managers/controller/tp_worker.py   (view file @ 5ac8b806)

@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-            8192
+            16384
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
@@ -222,30 +222,29 @@ class ModelTpServer:
         # Run decode batch
         if self.running_batch is not None:
             # Run a few decode batches continuously for reducing overhead
-            for _ in range(10):
+            for _ in range(global_config.num_continue_decode_steps):
                 self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)
 
                 # Print stats
-                if self.tp_rank == 0:
-                    if self.decode_forward_ct % 40 == 0:
-                        num_used = self.max_total_num_tokens - (
-                            self.token_to_kv_pool.available_size()
-                            + self.tree_cache.evictable_size()
-                        )
-                        throughput = self.num_generated_tokens / (
-                            time.time() - self.last_stats_tic
-                        )
-                        self.num_generated_tokens = 0
-                        self.last_stats_tic = time.time()
-                        logger.info(
-                            f"[gpu_id={self.gpu_id}] Decode batch. "
-                            f"#running-req: {len(self.running_batch.reqs)}, "
-                            f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                            f"gen throughput (token/s): {throughput:.2f}, "
-                            f"#queue-req: {len(self.forward_queue)}"
-                        )
+                if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                    num_used = self.max_total_num_tokens - (
+                        self.token_to_kv_pool.available_size()
+                        + self.tree_cache.evictable_size()
+                    )
+                    throughput = self.num_generated_tokens / (
+                        time.time() - self.last_stats_tic
+                    )
+                    self.num_generated_tokens = 0
+                    self.last_stats_tic = time.time()
+                    logger.info(
+                        f"[gpu_id={self.gpu_id}] Decode batch. "
+                        f"#running-req: {len(self.running_batch.reqs)}, "
+                        f"#token: {num_used}, "
+                        f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                        f"gen throughput (token/s): {throughput:.2f}, "
+                        f"#queue-req: {len(self.forward_queue)}"
+                    )
 
                 if self.running_batch.is_empty():
                     self.running_batch = None
@@ -344,7 +343,7 @@ class ModelTpServer:
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
@@ -358,7 +357,7 @@ class ModelTpServer:
                 req.prefix_indices = req.prefix_indices[:-delta]
                 if req.image_offset is not None:
                     req.image_offset += delta
-            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
+            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
                 # Need at least one token to compute logits
                 req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
@@ -366,7 +365,7 @@ class ModelTpServer:
                     req.image_offset += 1
 
             if (
-                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
                 < available_size
                 and (
                     req.extend_input_len + new_batch_input_tokens
@@ -378,7 +377,7 @@ class ModelTpServer:
                     available_size += delta
 
                 if not (
-                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                    req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
                     < available_size
                 ):
                     # Undo locking
@@ -389,7 +388,7 @@ class ModelTpServer:
                     # Add this request to the running batch
                     can_run_list.append(req)
                     new_batch_total_tokens += (
-                        req.extend_input_len + req.max_new_tokens()
+                        req.extend_input_len + req.sampling_params.max_new_tokens
                    )
                     new_batch_input_tokens += req.extend_input_len
                 else:
@@ -403,9 +402,6 @@ class ModelTpServer:
         # Print stats
         if self.tp_rank == 0:
-            running_req = (
-                0 if self.running_batch is None else len(self.running_batch.reqs)
-            )
             hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
             self.tree_cache_metrics["total"] += (
                 hit_tokens + new_batch_input_tokens
@@ -420,7 +416,7 @@ class ModelTpServer:
                 f"#new-token: {new_batch_input_tokens}, "
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                f"#running-req: {running_req}, "
+                f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
             )
             # logger.debug(
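
Note: besides raising the default prefill budget and reading the decode-loop length from global_config, the worker collapses the two nested stats conditions into one `tp_rank == 0 and decode_forward_ct % 40 == 0` check. A small sketch of that logging gate in isolation; the helper name and LOG_INTERVAL constant are assumptions for illustration, not names from the file:

    import logging

    logger = logging.getLogger(__name__)
    LOG_INTERVAL = 40  # mirrors the hard-coded "% 40" in the hunk above

    def maybe_log_decode_stats(tp_rank: int, decode_forward_ct: int, num_used: int, max_total: int):
        # Only TP rank 0 logs, and only every LOG_INTERVAL-th decode step.
        if tp_rank == 0 and decode_forward_ct % LOG_INTERVAL == 0:
            logger.info("Decode batch. #token: %d, token usage: %.2f",
                        num_used, num_used / max_total)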
python/sglang/srt/memory_pool.py   (view file @ 5ac8b806)

@@ -8,45 +8,45 @@ logger = logging.getLogger(__name__)
 
 
 class ReqToTokenPool:
-    def __init__(self, size, max_context_len):
+    """A memory pool that maps a request to its token locations."""
+
+    def __init__(self, size: int, max_context_len: int):
         self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
-        self.can_use_mem_size = size
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
+        self.can_use_mem_size = size
 
-    def alloc(self, need_size):
+    def alloc(self, need_size: int):
         if need_size > self.can_use_mem_size:
             return None
 
-        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
+        select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
         self.mem_state[select_index] = False
         self.can_use_mem_size -= need_size
 
-        return select_index.to(torch.int32)
+        return select_index
 
-    def free(self, free_index):
-        self.mem_state[free_index] = True
+    def free(self, free_index: int):
         if isinstance(free_index, (int,)):
             self.can_use_mem_size += 1
         else:
             self.can_use_mem_size += free_index.shape[0]
 
+        self.mem_state[free_index] = True
+
     def clear(self):
         self.mem_state.fill_(True)
         self.can_use_mem_size = len(self.mem_state)
 
 
 class TokenToKVPool:
+    """A memory pool that maps a token to its kv cache locations"""
+
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.size = size
 
-        # This can be promised:
-        # assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-        self.can_use_mem_size = self.size
 
         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
@@ -58,6 +58,7 @@ class TokenToKVPool:
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
 
+        self.can_use_mem_size = self.size
         self.clear()
 
     def get_key_buffer(self, layer_id):
@@ -66,6 +67,9 @@ class TokenToKVPool:
     def get_value_buffer(self, layer_id):
         return self.kv_data[layer_id][:, 1]
 
+    def available_size(self):
+        return self.can_use_mem_size + len(self.prefetch_buffer)
+
     def alloc(self, need_size):
         buffer_len = len(self.prefetch_buffer)
         if need_size <= buffer_len:
@@ -75,13 +79,13 @@ class TokenToKVPool:
         addition_size = need_size - buffer_len
         alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size]
-        select_index = select_index.to(torch.int32)
+        select_index = torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
 
         if select_index.shape[0] < addition_size:
             return None
 
-        self.add_refs(select_index)
+        self.mem_state[select_index] = False
+        self.can_use_mem_size -= len(select_index)
+
         self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
         ret_index = self.prefetch_buffer[:need_size]
@@ -89,16 +93,9 @@ class TokenToKVPool:
         return ret_index
 
-    def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
-
-    def add_refs(self, token_index: torch.Tensor):
-        self.can_use_mem_size -= len(token_index)
-        self.mem_state[token_index] = False
-
-    def dec_refs(self, token_index: torch.Tensor):
-        self.can_use_mem_size += len(token_index)
-        self.mem_state[token_index] = True
+    def free(self, free_index: torch.Tensor):
+        self.mem_state[free_index] = True
+        self.can_use_mem_size += len(free_index)
 
     def clear(self):
         self.mem_state.fill_(True)
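
Note: this file is the core of the commit. Both pools now track slot availability with a boolean mask plus a running counter, and the `add_refs`/`dec_refs` pair on TokenToKVPool collapses into a single `free`. A self-contained sketch of that simplified "mem state" scheme, for illustration only; the real pools above additionally carry the `req_to_token` / `kv_data` buffers, a prefetch buffer, and live on CUDA:

    import torch

    class BoolSlotPool:
        def __init__(self, size: int, device: str = "cpu"):
            self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)  # True = free
            self.can_use_mem_size = size

        def available_size(self) -> int:
            return self.can_use_mem_size

        def alloc(self, need_size: int):
            if need_size > self.can_use_mem_size:
                return None  # caller is expected to evict from the radix cache and retry
            # Pick the first need_size free slots and mark them as used.
            # (The real pools also cast these indices to int32 before handing them out.)
            select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
            self.mem_state[select_index] = False
            self.can_use_mem_size -= need_size
            return select_index

        def free(self, free_index: torch.Tensor):
            # No reference counts: freeing simply flips the slots back to "free".
            self.mem_state[free_index] = True
            self.can_use_mem_size += len(free_index)

        def clear(self):
            self.mem_state.fill_(True)
            self.can_use_mem_size = len(self.mem_state)

    # Example usage:
    pool = BoolSlotPool(8)
    idx = pool.alloc(3)
    pool.free(idx)
    assert pool.available_size() == 8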