Commit 61e68783 authored by zbian, committed by アマデウス

Fixed: weights could not be accessed correctly when using ZeRO together with tensor parallelism

parent eb5cf943
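
The commit rewrites ColossalaiModule in _utils.py to keep the wrapped module as a registered submodule and delegate attribute lookups to it, instead of copying the module's __dict__ onto the wrapper, and updates Dropout.forward to dispatch through the wrapper accordingly. Both files are shown below, before and after the change.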
_utils.py (before):

import torch.nn as nn
from torch import Tensor

from ..parallel_2d._operation import split_batch_2d
from ..parallel_2p5d._operation import split_batch_2p5d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode

_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}


def partition_batch(input_) -> Tensor:
    # split the batch according to the current tensor parallel mode;
    # dicts are split value by value, and inputs pass through unchanged
    # for modes without a registered split function
    tensor_parallel_mode = get_tensor_parallel_mode()
    if tensor_parallel_mode in _parallel_split_batch:
        if isinstance(input_, dict):
            return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
        else:
            return _parallel_split_batch[tensor_parallel_mode](input_)
    else:
        return input_


class ColossalaiModule(nn.Module):

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        # copy values: the wrapper takes over the wrapped module's __dict__,
        # so the module itself is never registered as a submodule
        self.__dict__ = module.__dict__.copy()
        # copy methods
        for name, attr in module.__class__.__dict__.items():
            if name not in ['__init__', 'forward'] and callable(attr):
                setattr(self, name, getattr(module, name))
        self._forward_func = module.forward
        for k, v in kwargs.items():
            setattr(self, k, v)

    def forward(self, *args):
        return self._forward_func(*args)
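
A minimal sketch (not part of the commit) of why the __dict__-copying wrapper above is fragile: wrapping a plain nn.Linear this way re-homes the layer's parameters onto the wrapper and shares its internal registries, while the linear itself never appears in the wrapper's module tree. The DictCopyWrapper name is hypothetical.

import torch.nn as nn

class DictCopyWrapper(nn.Module):
    """Hypothetical reduction of the old ColossalaiModule to its core."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.__dict__ = module.__dict__.copy()
        self._forward_func = module.forward

    def forward(self, *args):
        return self._forward_func(*args)

linear = nn.Linear(4, 4)
wrapped = DictCopyWrapper(linear)

print(list(wrapped._modules))                            # []: the linear is not a submodule
print(wrapped._parameters is linear._parameters)         # True: the shallow copy shares the dicts
print([name for name, _ in wrapped.named_parameters()])  # ['weight', 'bias']: re-homed onto the wrapper

Because the wrapper owns the parameters directly rather than through a registered child, anything that rebinds or shards parameters on one object, as ZeRO-style wrappers do, can fall out of sync with the other; the rewritten version below keeps the module in the tree instead.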
_utils.py (after):

import torch.nn as nn
from torch import Tensor

from ..parallel_2d._operation import split_batch_2d
from ..parallel_2p5d._operation import split_batch_2p5d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode

_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}


def partition_batch(input_) -> Tensor:
    tensor_parallel_mode = get_tensor_parallel_mode()
    if tensor_parallel_mode in _parallel_split_batch:
        if isinstance(input_, dict):
            return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
        else:
            return _parallel_split_batch[tensor_parallel_mode](input_)
    else:
        return input_


class ColossalaiModule(nn.Module):

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        # registered as a submodule, so its parameters stay visible
        # through the normal nn.Module tree
        self.module = module
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __getattr__(self, name: str):
        # __getattr__ is only called when normal lookup fails: resolve
        # 'module' through nn.Module's own machinery, then delegate
        # everything else to the wrapped module
        if name == 'module':
            return super().__getattr__(name)
        elif hasattr(self.module, name):
            return getattr(self.module, name)
        elif name in self.__dict__:
            return self.__dict__[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))

    def forward(self, *args):
        return self.module(*args)
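
A short usage sketch (assumed, not from the commit) of the rewritten wrapper: attribute access such as .weight is forwarded to the wrapped module, and the module's parameters appear under the module. prefix in the wrapper's parameter tree, which is what downstream parameter-sharding wrappers iterate over.

import torch
import torch.nn as nn

linear = nn.Linear(4, 4)
wrapped = ColossalaiModule(linear, tensor_parallel=None)

print(wrapped.weight is linear.weight)                   # True: __getattr__ delegates to the linear
print([name for name, _ in wrapped.named_parameters()])  # ['module.weight', 'module.bias']

x = torch.randn(2, 4)
print(torch.equal(wrapped(x), linear(x)))                # True: forward just calls the inner module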
The Dropout wrapper (before):

import torch.nn as nn

from colossalai.context import ParallelMode, seed

from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
from ._utils import ColossalaiModule


class Dropout(ColossalaiModule):
    """Dropout layer of colossalai.

    Args:
        p (float, optional): probability of an element to be zeroed, defaults to 0.5.
        inplace (bool, optional): whether to operate in-place, defaults to False.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == "1d":
            drop = Dropout1D(p, inplace)
        else:
            drop = nn.Dropout(p, inplace)
        super().__init__(drop, tensor_parallel=tensor_parallel)

    def forward(self, *args):
        # for non-1d tensor parallel modes, draw the dropout mask from the
        # RNG state dedicated to ParallelMode.TENSOR
        if self.tensor_parallel in [None, '1d']:
            return self._forward_func(*args)
        else:
            with seed(ParallelMode.TENSOR):
                return self._forward_func(*args)
The Dropout wrapper (after):

import torch.nn as nn

from colossalai.context import ParallelMode, seed

from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
from ._utils import ColossalaiModule


class Dropout(ColossalaiModule):
    """Dropout layer of colossalai.

    Args:
        p (float, optional): probability of an element to be zeroed, defaults to 0.5.
        inplace (bool, optional): whether to operate in-place, defaults to False.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == "1d":
            drop = Dropout1D(p, inplace)
        else:
            drop = nn.Dropout(p, inplace)
        super().__init__(drop, tensor_parallel=tensor_parallel)

    def forward(self, *args):
        # _forward_func is gone; dispatch through the rewritten
        # ColossalaiModule.forward, which calls the wrapped dropout module
        if self.tensor_parallel in [None, '1d']:
            return super().forward(*args)
        else:
            with seed(ParallelMode.TENSOR):
                return super().forward(*args)
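
For the non-1d modes, seed(ParallelMode.TENSOR) swaps in an RNG state reserved for tensor parallelism before the dropout mask is sampled and restores the previous state afterwards. Below is a standalone sketch of that pattern in plain PyTorch, assumed for illustration only; the real context manager needs an initialized Colossal-AI parallel context, and tp_dropout/_TP_RNG_STATE are hypothetical names.

import torch
import torch.nn.functional as F

_TP_RNG_STATE = None  # hypothetical parked RNG state standing in for ParallelMode.TENSOR

def tp_dropout(x, p=0.5, seed=1024):
    """Run dropout under a dedicated RNG stream, then restore the default one."""
    global _TP_RNG_STATE
    outer_state = torch.get_rng_state()      # save the default stream
    if _TP_RNG_STATE is None:
        torch.manual_seed(seed)              # first use: seed the dedicated stream
    else:
        torch.set_rng_state(_TP_RNG_STATE)   # resume the dedicated stream
    out = F.dropout(x, p, training=True)
    _TP_RNG_STATE = torch.get_rng_state()    # park the dedicated stream
    torch.set_rng_state(outer_state)         # restore the default stream
    return out

torch.manual_seed(0)
a = torch.rand(3)
torch.manual_seed(0)
_ = tp_dropout(torch.ones(2, 4))
b = torch.rand(3)
print(torch.equal(a, b))  # True: dropout did not consume from the default stream

Keeping dropout on its own stream means the mask draw neither consumes from nor disturbs the default RNG, and the seed registered for the tensor mode is typically rank-dependent, so activations partitioned across tensor-parallel ranks drop independently while other randomness stays aligned.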