Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
11668533
Unverified
Commit
11668533
authored
Nov 17, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 17, 2024
Browse files
Fix cuda illegal memory access in overlap mode (#2070)
parent
a9e90b4b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
11 deletions
+10
-11
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+0
-3
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+3
-0
test/srt/test_srt_engine.py
test/srt/test_srt_engine.py
+7
-8
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
11668533
...
@@ -1055,9 +1055,6 @@ class ScheduleBatch:
...
@@ -1055,9 +1055,6 @@ class ScheduleBatch:
)
)
def
copy
(
self
):
def
copy
(
self
):
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_
=
self
.
seq_lens
[
0
].
item
()
# Only contain fields that will be used by process_batch_result
# Only contain fields that will be used by process_batch_result
return
ScheduleBatch
(
return
ScheduleBatch
(
reqs
=
self
.
reqs
,
reqs
=
self
.
reqs
,
...
...
python/sglang/srt/managers/scheduler.py
View file @
11668533
...
@@ -390,6 +390,9 @@ class Scheduler:
...
@@ -390,6 +390,9 @@ class Scheduler:
batch
=
self
.
get_next_batch_to_run
()
batch
=
self
.
get_next_batch_to_run
()
self
.
cur_batch
=
batch
self
.
cur_batch
=
batch
if
batch
:
if
batch
:
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_
=
batch
.
seq_lens
[
0
].
item
()
result
=
self
.
run_batch
(
batch
)
result
=
self
.
run_batch
(
batch
)
result_queue
.
append
((
batch
.
copy
(),
result
))
result_queue
.
append
((
batch
.
copy
(),
result
))
...
...
test/srt/test_srt_engine.py
View file @
11668533
...
@@ -16,7 +16,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
...
@@ -16,7 +16,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.test.few_shot_gsm8k_engine
import
run_eval
from
sglang.test.few_shot_gsm8k_engine
import
run_eval
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
)
)
...
@@ -43,7 +42,7 @@ class TestSRTEngine(unittest.TestCase):
...
@@ -43,7 +42,7 @@ class TestSRTEngine(unittest.TestCase):
print
(
"==== Answer 2 ===="
)
print
(
"==== Answer 2 ===="
)
print
(
out2
)
print
(
out2
)
assert
out1
==
out2
,
f
"
{
out1
}
!=
{
out2
}
"
self
.
assert
Equal
(
out1
,
out2
)
def
test_2_engine_multiple_generate
(
self
):
def
test_2_engine_multiple_generate
(
self
):
# just to ensure there is no issue running multiple generate calls
# just to ensure there is no issue running multiple generate calls
...
@@ -106,14 +105,14 @@ class TestSRTEngine(unittest.TestCase):
...
@@ -106,14 +105,14 @@ class TestSRTEngine(unittest.TestCase):
def
test_4_gsm8k
(
self
):
def
test_4_gsm8k
(
self
):
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
model_path
=
DEFAULT_MODEL_NAME_FOR_TEST
,
model_path
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
local_data_path
=
None
,
local_data_path
=
None
,
num_shots
=
5
,
num_shots
=
5
,
num_questions
=
200
,
num_questions
=
200
,
)
)
metrics
=
run_eval
(
args
)
metrics
=
run_eval
(
args
)
assert
metrics
[
"accuracy"
]
>
0.
7
self
.
assert
Greater
(
metrics
[
"accuracy"
]
,
0.
3
)
def
test_5_prompt_input_ids_consistency
(
self
):
def
test_5_prompt_input_ids_consistency
(
self
):
prompt
=
"The capital of UK is"
prompt
=
"The capital of UK is"
...
@@ -136,7 +135,7 @@ class TestSRTEngine(unittest.TestCase):
...
@@ -136,7 +135,7 @@ class TestSRTEngine(unittest.TestCase):
print
(
"==== Answer 2 ===="
)
print
(
"==== Answer 2 ===="
)
print
(
out2
)
print
(
out2
)
assert
out1
==
out2
,
f
"
{
out1
}
!=
{
out2
}
"
self
.
assert
Equal
(
out1
,
out2
)
def
test_6_engine_runtime_encode_consistency
(
self
):
def
test_6_engine_runtime_encode_consistency
(
self
):
prompt
=
"Today is a sunny day and I like"
prompt
=
"Today is a sunny day and I like"
...
@@ -156,11 +155,11 @@ class TestSRTEngine(unittest.TestCase):
...
@@ -156,11 +155,11 @@ class TestSRTEngine(unittest.TestCase):
def
test_7_engine_offline_throughput
(
self
):
def
test_7_engine_offline_throughput
(
self
):
server_args
=
ServerArgs
(
server_args
=
ServerArgs
(
model_path
=
DEFAULT_MODEL_NAME_FOR_TEST
,
model_path
=
DEFAULT_
SMALL_
MODEL_NAME_FOR_TEST
,
)
)
bench_args
=
BenchArgs
(
num_prompts
=
10
0
)
bench_args
=
BenchArgs
(
num_prompts
=
10
)
result
=
throughput_test
(
server_args
=
server_args
,
bench_args
=
bench_args
)
result
=
throughput_test
(
server_args
=
server_args
,
bench_args
=
bench_args
)
self
.
assert
True
(
result
[
"total_throughput"
]
>
3
0
00
)
self
.
assert
Greater
(
result
[
"total_throughput"
]
,
3
5
00
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment