Unverified Commit 46c009db authored by Hakjin Lee's avatar Hakjin Lee Committed by GitHub
Browse files

[format] Run lint on colossalai.engine (#3367)

parent b9231390
from typing import Iterable, List
import torch.nn as nn
from typing import List
from colossalai.engine import BaseGradientHandler
from typing import Iterable
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
from colossalai.engine import BaseGradientHandler
from ._gradient_accumulation import (
GradAccumDataloader,
GradAccumGradientHandler,
GradAccumLrSchedulerByStep,
GradAccumOptimizer,
)
__all__ = [
'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
......
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
......
......@@ -4,9 +4,10 @@ from collections import defaultdict
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from ._base_gradient_handler import BaseGradientHandler
......
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
......
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
......
from ._base_schedule import BaseSchedule
from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from ._non_pipeline_schedule import NonPipelineSchedule
from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape
__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
......@@ -2,10 +2,10 @@
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Callable, Iterable
import torch
from typing import Iterable, Callable
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Iterable
import inspect
from typing import Callable, Iterable
import torch
import inspect
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
from typing import Callable
from ._base_schedule import BaseSchedule
class NonPipelineSchedule(BaseSchedule):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment