Unverified Commit 9c195fe2 authored by Crutcher Dunnavant's avatar Crutcher Dunnavant Committed by GitHub
Browse files

[minor] ThreadLocal to ThreadLocalCheckpointState dataclass (#1007)

* ThreadLocal to ThreadLocalCheckpointState dataclass

* remove notes
parent 32b0b98e
......@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from dataclasses import dataclass
import functools
import threading
from typing import Any, Dict, Generator, Optional, Tuple
......@@ -21,14 +22,14 @@ from .checkpoint_utils import patch_batchnorm
# https://docs.python.org/3/library/threading.html#thread-local-data
# Manage the checkpoint context with thread-local data.
class ThreadLocal(threading.local):
def __init__(self) -> None:
self.is_checkpointing = False
self.is_recomputing = False
self.is_checkpointing_disabled = False
@dataclass
class ThreadLocalCheckpointingState(threading.local):
is_checkpointing: bool = False
is_recomputing: bool = False
is_checkpointing_disabled: bool = False
thread_local = ThreadLocal()
thread_local = ThreadLocalCheckpointingState()
@contextmanager
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment