Unverified commit 11668533, authored by Lianmin Zheng and committed by GitHub
Browse files

Fix cuda illegal memory access in overlap mode (#2070)

parent a9e90b4b
...@@ -1055,9 +1055,6 @@ class ScheduleBatch: ...@@ -1055,9 +1055,6 @@ class ScheduleBatch:
) )
def copy(self): def copy(self):
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_ = self.seq_lens[0].item()
# Only contain fields that will be used by process_batch_result # Only contain fields that will be used by process_batch_result
return ScheduleBatch( return ScheduleBatch(
reqs=self.reqs, reqs=self.reqs,
......
...@@ -390,6 +390,9 @@ class Scheduler: ...@@ -390,6 +390,9 @@ class Scheduler:
batch = self.get_next_batch_to_run() batch = self.get_next_batch_to_run()
self.cur_batch = batch self.cur_batch = batch
if batch: if batch:
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_ = batch.seq_lens[0].item()
result = self.run_batch(batch) result = self.run_batch(batch)
result_queue.append((batch.copy(), result)) result_queue.append((batch.copy(), result))
......
...@@ -16,7 +16,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer ...@@ -16,7 +16,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.test.few_shot_gsm8k_engine import run_eval from sglang.test.few_shot_gsm8k_engine import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
) )
...@@ -43,7 +42,7 @@ class TestSRTEngine(unittest.TestCase): ...@@ -43,7 +42,7 @@ class TestSRTEngine(unittest.TestCase):
print("==== Answer 2 ====") print("==== Answer 2 ====")
print(out2) print(out2)
assert out1 == out2, f"{out1} != {out2}" self.assertEqual(out1, out2)
def test_2_engine_multiple_generate(self): def test_2_engine_multiple_generate(self):
# just to ensure there is no issue running multiple generate calls # just to ensure there is no issue running multiple generate calls
...@@ -106,14 +105,14 @@ class TestSRTEngine(unittest.TestCase): ...@@ -106,14 +105,14 @@ class TestSRTEngine(unittest.TestCase):
def test_4_gsm8k(self): def test_4_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
model_path=DEFAULT_MODEL_NAME_FOR_TEST, model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
local_data_path=None, local_data_path=None,
num_shots=5, num_shots=5,
num_questions=200, num_questions=200,
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["accuracy"] > 0.7 self.assertGreater(metrics["accuracy"], 0.3)
def test_5_prompt_input_ids_consistency(self): def test_5_prompt_input_ids_consistency(self):
prompt = "The capital of UK is" prompt = "The capital of UK is"
...@@ -136,7 +135,7 @@ class TestSRTEngine(unittest.TestCase): ...@@ -136,7 +135,7 @@ class TestSRTEngine(unittest.TestCase):
print("==== Answer 2 ====") print("==== Answer 2 ====")
print(out2) print(out2)
assert out1 == out2, f"{out1} != {out2}" self.assertEqual(out1, out2)
def test_6_engine_runtime_encode_consistency(self): def test_6_engine_runtime_encode_consistency(self):
prompt = "Today is a sunny day and I like" prompt = "Today is a sunny day and I like"
...@@ -156,11 +155,11 @@ class TestSRTEngine(unittest.TestCase): ...@@ -156,11 +155,11 @@ class TestSRTEngine(unittest.TestCase):
def test_7_engine_offline_throughput(self): def test_7_engine_offline_throughput(self):
server_args = ServerArgs( server_args = ServerArgs(
model_path=DEFAULT_MODEL_NAME_FOR_TEST, model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
) )
bench_args = BenchArgs(num_prompts=100) bench_args = BenchArgs(num_prompts=10)
result = throughput_test(server_args=server_args, bench_args=bench_args) result = throughput_test(server_args=server_args, bench_args=bench_args)
self.assertTrue(result["total_throughput"] > 3000) self.assertGreater(result["total_throughput"], 3500)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment