Unverified Commit 11668533 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix cuda illegal memory access in overlap mode (#2070)

parent a9e90b4b
......@@ -1055,9 +1055,6 @@ class ScheduleBatch:
)
def copy(self):
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_ = self.seq_lens[0].item()
# Only contain fields that will be used by process_batch_result
return ScheduleBatch(
reqs=self.reqs,
......
......@@ -390,6 +390,9 @@ class Scheduler:
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
# We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
_ = batch.seq_lens[0].item()
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
......
......@@ -16,7 +16,6 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server_args import ServerArgs
from sglang.test.few_shot_gsm8k_engine import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
)
......@@ -43,7 +42,7 @@ class TestSRTEngine(unittest.TestCase):
print("==== Answer 2 ====")
print(out2)
assert out1 == out2, f"{out1} != {out2}"
self.assertEqual(out1, out2)
def test_2_engine_multiple_generate(self):
# just to ensure there is no issue running multiple generate calls
......@@ -106,14 +105,14 @@ class TestSRTEngine(unittest.TestCase):
def test_4_gsm8k(self):
args = SimpleNamespace(
model_path=DEFAULT_MODEL_NAME_FOR_TEST,
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
local_data_path=None,
num_shots=5,
num_questions=200,
)
metrics = run_eval(args)
assert metrics["accuracy"] > 0.7
self.assertGreater(metrics["accuracy"], 0.3)
def test_5_prompt_input_ids_consistency(self):
prompt = "The capital of UK is"
......@@ -136,7 +135,7 @@ class TestSRTEngine(unittest.TestCase):
print("==== Answer 2 ====")
print(out2)
assert out1 == out2, f"{out1} != {out2}"
self.assertEqual(out1, out2)
def test_6_engine_runtime_encode_consistency(self):
prompt = "Today is a sunny day and I like"
......@@ -156,11 +155,11 @@ class TestSRTEngine(unittest.TestCase):
def test_7_engine_offline_throughput(self):
server_args = ServerArgs(
model_path=DEFAULT_MODEL_NAME_FOR_TEST,
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
)
bench_args = BenchArgs(num_prompts=100)
bench_args = BenchArgs(num_prompts=10)
result = throughput_test(server_args=server_args, bench_args=bench_args)
self.assertTrue(result["total_throughput"] > 3000)
self.assertGreater(result["total_throughput"], 3500)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment