Unverified Commit 308d0240 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[CI] Fix the issue of unit test hanging (#1211)

parent ab4990e4
...@@ -460,24 +460,25 @@ def run_with_timeout( ...@@ -460,24 +460,25 @@ def run_with_timeout(
return ret_value[0] return ret_value[0]
def run_one_file(filename, out_queue):
print(f"\n\nRun {filename}\n\n")
ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
def run_unittest_files(files: List[str], timeout_per_file: float): def run_unittest_files(files: List[str], timeout_per_file: float):
tic = time.time() tic = time.time()
success = True success = True
for filename in files: for filename in files:
out_queue = multiprocessing.Queue()
p = multiprocessing.Process(target=run_one_file, args=(filename, out_queue))
def func(): def run_process():
print(f"\n\nRun {filename}\n\n")
ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
p = multiprocessing.Process(target=func)
def run_one_file():
p.start() p.start()
p.join() p.join()
try: try:
run_with_timeout(run_one_file, timeout=timeout_per_file) run_with_timeout(run_process, timeout=timeout_per_file)
if p.exitcode != 0: if p.exitcode != 0:
success = False success = False
break break
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import multiprocessing as mp
import unittest import unittest
import torch import torch
...@@ -71,4 +72,9 @@ class TestEmbeddingModels(unittest.TestCase): ...@@ -71,4 +72,9 @@ class TestEmbeddingModels(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore") unittest.main(warnings="ignore")
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import multiprocessing as mp
import unittest import unittest
import torch import torch
...@@ -108,13 +109,6 @@ class TestGenerationModels(unittest.TestCase): ...@@ -108,13 +109,6 @@ class TestGenerationModels(unittest.TestCase):
), f"Not all ROUGE-L scores are greater than {rouge_threshold}" ), f"Not all ROUGE-L scores are greater than {rouge_threshold}"
def test_prefill_logits_and_output_strs(self): def test_prefill_logits_and_output_strs(self):
import multiprocessing as mp
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
for ( for (
model, model,
tp_size, tp_size,
...@@ -137,4 +131,9 @@ class TestGenerationModels(unittest.TestCase): ...@@ -137,4 +131,9 @@ class TestGenerationModels(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore") unittest.main(warnings="ignore")
import argparse import argparse
import glob import glob
import multiprocessing as mp
from sglang.test.test_utils import run_unittest_files from sglang.test.test_utils import run_unittest_files
...@@ -54,5 +55,10 @@ if __name__ == "__main__": ...@@ -54,5 +55,10 @@ if __name__ == "__main__":
else: else:
files = suites[args.suite] files = suites[args.suite]
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
exit_code = run_unittest_files(files, args.timeout_per_file) exit_code = run_unittest_files(files, args.timeout_per_file)
exit(exit_code) exit(exit_code)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment