Unverified Commit 9c195fe2 authored by Crutcher Dunnavant's avatar Crutcher Dunnavant Committed by GitHub
Browse files

[minor] ThreadLocal to ThreadLocalCheckpointState dataclass (#1007)

* ThreadLocal to ThreadLocalCheckpointState dataclass

* remove notes
parent 32b0b98e
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
import functools import functools
import threading import threading
from typing import Any, Dict, Generator, Optional, Tuple from typing import Any, Dict, Generator, Optional, Tuple
...@@ -21,14 +22,14 @@ from .checkpoint_utils import patch_batchnorm ...@@ -21,14 +22,14 @@ from .checkpoint_utils import patch_batchnorm
# https://docs.python.org/3/library/threading.html#thread-local-data # https://docs.python.org/3/library/threading.html#thread-local-data
# Manage the checkpoint context with thread-local data. # Manage the checkpoint context with thread-local data.
class ThreadLocal(threading.local): @dataclass
def __init__(self) -> None: class ThreadLocalCheckpointingState(threading.local):
self.is_checkpointing = False is_checkpointing: bool = False
self.is_recomputing = False is_recomputing: bool = False
self.is_checkpointing_disabled = False is_checkpointing_disabled: bool = False
thread_local = ThreadLocal() thread_local = ThreadLocalCheckpointingState()
@contextmanager @contextmanager
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment