Commit 571f12ef authored by Ziheng Qin's avatar Ziheng Qin Committed by binmakeswell
Browse files

[NFC] polish colossalai/nn/layer/utils/common.py code style (#983)

parent bda70b4b
...@@ -13,7 +13,8 @@ from torch import Tensor, nn ...@@ -13,7 +13,8 @@ from torch import Tensor, nn
class CheckpointModule(nn.Module): class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = True, offload : bool = False):
def __init__(self, checkpoint: bool = True, offload: bool = False):
super().__init__() super().__init__()
self.checkpoint = checkpoint self.checkpoint = checkpoint
self._use_checkpoint = checkpoint self._use_checkpoint = checkpoint
...@@ -78,6 +79,7 @@ def get_tensor_parallel_mode(): ...@@ -78,6 +79,7 @@ def get_tensor_parallel_mode():
def _ntuple(n): def _ntuple(n):
def parse(x): def parse(x):
if isinstance(x, collections.abc.Iterable): if isinstance(x, collections.abc.Iterable):
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment