Commit da3f0934 authored by zhuwenwen's avatar zhuwenwen
Browse files

delete unused files

parent c4dd1fd4
from .colossalai_layer import *
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .utils import *
from .vanilla import *
from .wrapper import *
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class ParallelLayer(nn.Module):
def __init__(self):
super().__init__()
self.data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(
ParallelMode.DATA)
self.data_parallel_size = 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size(
ParallelMode.DATA)
self.tensor_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank(
ParallelMode.TENSOR)
self.tensor_parallel_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(
ParallelMode.TENSOR)
self.pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
from ._utils import partition_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch']
from torch import Tensor
from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d}
def partition_batch(input_) -> Tensor:
tensor_parallel_mode = get_tensor_parallel_mode()
if tensor_parallel_mode in _parallel_split_batch:
if isinstance(input_, dict):
return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
else:
return _parallel_split_batch[tensor_parallel_mode](input_)
else:
return input_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment