Unverified Commit 25235779 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[ci][distributed] try to fix pp test (#7054)

parent 3bb4b1e4
...@@ -9,7 +9,7 @@ import os ...@@ -9,7 +9,7 @@ import os
import pytest import pytest
from ..utils import compare_two_settings from ..utils import compare_two_settings, fork_new_process_for_each_test
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
...@@ -28,6 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" ...@@ -28,6 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
]) ])
@fork_new_process_for_each_test
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
DIST_BACKEND): DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp": if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
...@@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, ...@@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
"FLASH_ATTN", "FLASH_ATTN",
"FLASHINFER", "FLASHINFER",
]) ])
@fork_new_process_for_each_test
def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND): def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
cudagraph_args = [ cudagraph_args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
......
import functools
import os import os
import signal
import subprocess import subprocess
import sys import sys
import time import time
...@@ -336,3 +338,40 @@ def wait_for_gpu_memory_to_clear(devices: List[int], ...@@ -336,3 +338,40 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
f'{dur_s=:.02f} ({threshold_bytes/2**30=})') f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
time.sleep(5) time.sleep(5)
def fork_new_process_for_each_test(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
pid = os.fork()
if pid == 0:
try:
f(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception:
import traceback
traceback.print_exc()
os._exit(1)
else:
os._exit(0)
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
# ignore SIGTERM signal itself
old_singla_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
# kill all child processes
os.killpg(pgid, signal.SIGTERM)
# restore the signal handler
signal.signal(signal.SIGTERM, old_singla_handler)
assert _exitcode == 0, (f"function {f} failed when called with"
f" args {args} and kwargs {kwargs}")
return wrapper
...@@ -3,7 +3,7 @@ import os ...@@ -3,7 +3,7 @@ import os
from typing import List, Optional from typing import List, Optional
try: try:
from ray.exceptions import ActorDiedError from ray.exceptions import ActorDiedError # type: ignore
except ImportError: except ImportError:
# For older versions of Ray # For older versions of Ray
from ray.exceptions import RayActorError as ActorDiedError # type: ignore from ray.exceptions import RayActorError as ActorDiedError # type: ignore
......
...@@ -928,7 +928,8 @@ def error_on_invalid_device_count_status(): ...@@ -928,7 +928,8 @@ def error_on_invalid_device_count_status():
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
# future pytorch will fix the issue, device_count will not be cached # future pytorch will fix the issue, device_count will not be cached
# at that time, `.cache_info().currsize` will error out # at that time, `.cache_info().currsize` will error out
cache_entries = torch.cuda.device_count.cache_info().currsize cache_entries = torch.cuda.device_count.cache_info( # type: ignore
).currsize
if cache_entries != 0: if cache_entries != 0:
# the function is already called, and the result is cached # the function is already called, and the result is cached
remembered = torch.cuda.device_count() remembered = torch.cuda.device_count()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment