Unverified Commit 308d0240 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[CI] Fix the issue of unit test hanging (#1211)

parent ab4990e4
......@@ -460,24 +460,25 @@ def run_with_timeout(
return ret_value[0]
def run_one_file(filename, out_queue):
print(f"\n\nRun {filename}\n\n")
ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
def run_unittest_files(files: List[str], timeout_per_file: float):
tic = time.time()
success = True
for filename in files:
out_queue = multiprocessing.Queue()
p = multiprocessing.Process(target=run_one_file, args=(filename, out_queue))
def func():
print(f"\n\nRun {filename}\n\n")
ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
p = multiprocessing.Process(target=func)
def run_one_file():
def run_process():
p.start()
p.join()
try:
run_with_timeout(run_one_file, timeout=timeout_per_file)
run_with_timeout(run_process, timeout=timeout_per_file)
if p.exitcode != 0:
success = False
break
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import multiprocessing as mp
import unittest
import torch
......@@ -71,4 +72,9 @@ class TestEmbeddingModels(unittest.TestCase):
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import multiprocessing as mp
import unittest
import torch
......@@ -108,13 +109,6 @@ class TestGenerationModels(unittest.TestCase):
), f"Not all ROUGE-L scores are greater than {rouge_threshold}"
def test_prefill_logits_and_output_strs(self):
import multiprocessing as mp
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
for (
model,
tp_size,
......@@ -137,4 +131,9 @@ class TestGenerationModels(unittest.TestCase):
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")
import argparse
import glob
import multiprocessing as mp
from sglang.test.test_utils import run_unittest_files
......@@ -54,5 +55,10 @@ if __name__ == "__main__":
else:
files = suites[args.suite]
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
exit_code = run_unittest_files(files, args.timeout_per_file)
exit(exit_code)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment