Unverified Commit 350ccc04 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[example] opt does not depend on Titans (#1811)

parent 6fa71d65
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class barrier_context():
"""
This context manager is used to allow one process to execute while blocking all
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.
Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
parallel_mode (ParallelMode): the parallel mode corresponding to a process group
Usage:
with barrier_context():
dataset = CIFAR10(root='./data', download=True)
"""
def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
# the class name is lowercase by convention
current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
self.should_block = current_rank != executor_rank
self.group = gpc.get_group(parallel_mode=parallel_mode)
def __enter__(self):
if self.should_block:
dist.barrier(group=self.group)
def __exit__(self, exc_type, exc_value, exc_traceback):
if not self.should_block:
dist.barrier(group=self.group)
This diff is collapsed.
......@@ -3,3 +3,4 @@ torch >= 1.8.1
datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2
......@@ -32,9 +32,9 @@ import datasets
import torch
import torch.distributed as dist
from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset
from packaging import version
from titans.utils import barrier_context
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from utils import colo_memory_cap
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment