Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6db27f7b
Unverified
Commit
6db27f7b
authored
Aug 07, 2024
by
Zhiqiang Xie
Committed by
GitHub
Aug 08, 2024
Browse files
misc: correct the int data type for token ids and indices (#969)
parent
4d929107
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
3 deletions
+3
-3
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+1
-1
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+2
-2
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
6db27f7b
...
@@ -780,7 +780,7 @@ def top_k_top_p_sampling_from_probs_torch(
...
@@ -780,7 +780,7 @@ def top_k_top_p_sampling_from_probs_torch(
sampled_index
=
torch
.
multinomial
(
probs_sort
,
num_samples
=
1
)
sampled_index
=
torch
.
multinomial
(
probs_sort
,
num_samples
=
1
)
except
RuntimeError
:
except
RuntimeError
:
batch_next_token_ids
=
torch
.
zeros
(
batch_next_token_ids
=
torch
.
zeros
(
(
probs_sort
.
shape
[
0
],),
dtype
=
torch
.
int
64
,
device
=
probs
.
device
(
probs_sort
.
shape
[
0
],),
dtype
=
torch
.
int
32
,
device
=
probs
.
device
)
)
success
=
torch
.
zeros
(
probs
.
shape
[
0
],
dtype
=
torch
.
bool
,
device
=
probs
.
device
)
success
=
torch
.
zeros
(
probs
.
shape
[
0
],
dtype
=
torch
.
bool
,
device
=
probs
.
device
)
return
batch_next_token_ids
,
success
return
batch_next_token_ids
,
success
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
6db27f7b
...
@@ -74,7 +74,7 @@ class RadixCache(BasePrefixCache):
...
@@ -74,7 +74,7 @@ class RadixCache(BasePrefixCache):
if
value
:
if
value
:
value
=
torch
.
concat
(
value
)
value
=
torch
.
concat
(
value
)
else
:
else
:
value
=
torch
.
tensor
([],
dtype
=
torch
.
int
64
)
value
=
torch
.
tensor
([],
dtype
=
torch
.
int
32
)
return
value
,
last_node
[
0
]
return
value
,
last_node
[
0
]
def
insert
(
self
,
key
,
value
=
None
):
def
insert
(
self
,
key
,
value
=
None
):
...
@@ -102,7 +102,7 @@ class RadixCache(BasePrefixCache):
...
@@ -102,7 +102,7 @@ class RadixCache(BasePrefixCache):
if
del_in_memory_pool
:
if
del_in_memory_pool
:
self
.
token_to_kv_pool
.
free
(
indices
)
self
.
token_to_kv_pool
.
free
(
indices
)
else
:
else
:
return
torch
.
tensor
([],
dtype
=
torch
.
int
64
),
self
.
root_node
return
torch
.
tensor
([],
dtype
=
torch
.
int
32
),
self
.
root_node
# Radix Cache takes one ref in memory pool
# Radix Cache takes one ref in memory pool
self
.
token_to_kv_pool
.
free
(
indices
[
last_uncached_pos
:
new_prefix_len
])
self
.
token_to_kv_pool
.
free
(
indices
[
last_uncached_pos
:
new_prefix_len
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment