async_utils.py 2.04 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""
This module provides a singleton instance of AsyncCallsQueue which manages
the async checkpoint save calls.
"""
import logging

from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
from megatron.training import get_args
from megatron.training.utils import print_rank_0

logger = logging.getLogger(__name__)

# Singleton manager of async checkpoint-save calls, used by the helpers in this
# module. Created with AsyncCallsQueue's default settings (per the upstream
# note, the default caller is `TemporalAsyncCaller`);
# `init_persistent_async_worker` rebinds it to a persistent-worker queue.
_async_calls_queue = AsyncCallsQueue()


def init_persistent_async_worker() -> None:
    """Switch the module-level async calls queue to a persistent worker.

    Rebinds the module global ``_async_calls_queue`` to a new
    ``AsyncCallsQueue(persistent=True)``. The other helpers in this module look
    the global up at call time, so they use the new queue from then on.
    """
    global _async_calls_queue
    # A default (non-persistent) queue is already created at import time;
    # recreating it here is a duplicate step kept for backward compatibility.
    # NOTE(review): the previous queue is not closed before being replaced —
    # presumably it is expected to hold no pending calls here; confirm with callers.
    _async_calls_queue = AsyncCallsQueue(persistent=True)


def schedule_async_save(async_request: AsyncRequest) -> None:
    """Schedule the async save request.

    Args:
        async_request (AsyncRequest): the async save request; it is passed to
            the module-level queue's ``schedule_async_request``.
    """
    _async_calls_queue.schedule_async_request(async_request)


def maybe_finalize_async_save(blocking: bool = False, terminate: bool = False) -> None:
    """Finalizes active async save calls.

    This is a no-op (including the ``terminate`` step) unless
    ``get_args().async_save`` is set.

    Args:
        blocking (bool, optional): if True, will wait until all active requests
            are done. Otherwise, finalizes only the async requests that already
            finished. Defaults to False.
        terminate (bool, optional): if True, the asynchronous queue will be
            closed as the last action of this function. Defaults to False.
    """
    args = get_args()
    if not args.async_save:
        return

    # Only announce synchronous finalization when there is actually work pending.
    if blocking and not is_empty_async_queue():
        print_rank_0('Unfinalized async checkpoint saves. Finalizing them synchronously now.')

    # no_dist=False: finalize in distributed mode.
    # NOTE(review): presumably this means every rank must call this function — confirm.
    _async_calls_queue.maybe_finalize_async_calls(blocking, no_dist=False)

    if terminate:
        _async_calls_queue.close()


def is_empty_async_queue() -> bool:
    """Check if async calls queue is empty. This result is consistent across ranks.

    Returns:
        bool: True if there are no unfinalized async calls, False if at least
            one async call is still unfinalized.
    """
    return _async_calls_queue.get_num_unfinalized_calls() == 0