Unverified Commit ddcf58ca authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

Revert "[sync] sync feature/shardformer with develop"

parent 24651fdd
# 🗄 Device
## 📚 Table of Contents
- [🗄 Device](#-device)
- [📚 Table of Contents](#-table-of-contents)
- [🔗 Introduction](#-introduction)
- [📝 Design](#-design)
- [🔨 Usage](#-usage)
## 🔗 Introduction
This module contains the implementation of the abstraction of the device topology. It is used to represent the device topology and manage the distributed information related to the network.
## 📝 Design
This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa) and the device array can be represented as a 1D or 2D mesh. We will be extending the device mesh to support 3D mesh in the future.
## 🔨 Usage
- Create a device mesh
```python
# this is the list of global ranks involved in the device mesh
# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
physical_mesh_id = torch.arange(4)
mesh_shape = [2, 2]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
```
- View the mesh
```python
# view the mesh shape
# expect output
# [2, 2]
print(device_mesh.shape)
# view the logical mesh with global ranks
# expect output
# [
# [0, 1],
# [2, 3]
# ]
print(device_mesh.logical_mesh_id)
# view the number of devices in the mesh
# expect output
# 4
print(device_mesh.num_devices)
```
- Initialize the process group
```python
# intialize process group
device_mesh.init_logical_process_group()
# get the process group for a rank with respect to an axis
# this is the process group involving global ranks 0 and 2
print(device_mesh.get_process_group(axis=0, global_rank=0))
# get the ranks in the process with respect to an axis
# expect output
# [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
```
...@@ -3,19 +3,11 @@ ...@@ -3,19 +3,11 @@
with some changes. """ with some changes. """
import operator import operator
from dataclasses import dataclass
from functools import reduce from functools import reduce
from typing import Dict, List, Union from typing import List, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup
@dataclass
class ProcessGroupContainer:
process_group: ProcessGroup
ranks: List[int]
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py) # modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
...@@ -35,11 +27,9 @@ class DeviceMesh: ...@@ -35,11 +27,9 @@ class DeviceMesh:
during initializing the DeviceMesh instance if the init_process_group set to True. during initializing the DeviceMesh instance if the init_process_group set to True.
Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
(default: False) (default: False)
device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda') need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
""" """
_DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
def __init__(self, def __init__(self,
physical_mesh_id: torch.Tensor, physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None, mesh_shape: torch.Size = None,
...@@ -47,140 +37,48 @@ class DeviceMesh: ...@@ -47,140 +37,48 @@ class DeviceMesh:
mesh_alpha: List[float] = None, mesh_alpha: List[float] = None,
mesh_beta: List[float] = None, mesh_beta: List[float] = None,
init_process_group: bool = False, init_process_group: bool = False,
device: str = 'cuda'): need_flatten: bool = True):
# ============================ self.physical_mesh_id = physical_mesh_id
# Physical & Logical Mesh IDs
# ============================
self._physical_mesh_id = physical_mesh_id
assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."
# logical mesh ids can be obtained via two ways
# 1. provide physical mesh id and provide mesh shape
# 2. directly supply the logical mesh id
assert mesh_shape is None or logical_mesh_id is None, \
"Only one of mesh_shape and logical_mesh_id can be specified." \
"Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
if logical_mesh_id is None: if logical_mesh_id is None:
self.mesh_shape = mesh_shape self.mesh_shape = mesh_shape
self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape) self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
else: else:
self._logical_mesh_id = logical_mesh_id self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape self.mesh_shape = self._logical_mesh_id.shape
# ensure two things: # map global rank into logical rank
# 1. logical and physical mesh IDs should contain the same elements self.convert_map = {}
# 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
"Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
# ===============================================
# coefficient for alpha-beta communication model # coefficient for alpha-beta communication model
# alpha is latency and beta is bandwidth
# ===============================================
# if the values are not provided, we assume they are 1 for simplicity
if mesh_alpha is None: if mesh_alpha is None:
mesh_alpha = [1] * len(self.mesh_shape) mesh_alpha = [1] * len(self.mesh_shape)
if mesh_beta is None: if mesh_beta is None:
mesh_beta = [1] * len(self.mesh_shape) mesh_beta = [1] * len(self.mesh_shape)
self.mesh_alpha = tuple(mesh_alpha) self.mesh_alpha = tuple(mesh_alpha)
self.mesh_beta = tuple(mesh_beta) self.mesh_beta = tuple(mesh_beta)
self.init_process_group = init_process_group
# ensure the alpha and beta have the same shape self.need_flatten = need_flatten
assert len(self.mesh_alpha) == len(self.mesh_beta), \ if self.init_process_group:
"mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." self.process_groups_dict = self.create_process_groups_for_logical_mesh()
if self.need_flatten and self._logical_mesh_id.dim() > 1:
# ========================= self.flatten_device_mesh = self.flatten()
# Device for Process Group # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
# ========================= # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
self._device = device # self.mesh_beta)
self._dist_backend = self._DIST_BACKEND[device]
# =========================
# Process Group Management
# =========================
# the _global_to_local_rank_mapping is structured as follows
# {
# <global-rank>: [ <local-rank-on-axis-0>, <local-rank-on-axis-1>, <local-rank-on-axis-2>, ...]
# }
self._global_to_local_rank_mapping = dict()
self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
tensor=self.logical_mesh_id)
# create process group
self._process_group_dict = {}
self._ranks_in_the_process_group = {}
self._global_rank_of_current_process = None
self._is_initialized = False
# initialize process group if specified
self._init_ranks_in_the_same_group()
self._init_process_group = init_process_group
if init_process_group:
self.init_logical_process_group()
@property @property
def shape(self) -> torch.Size: def shape(self):
"""
Return the shape of the logical mesh.
"""
return self.mesh_shape return self.mesh_shape
@property @property
def num_devices(self) -> int: def num_devices(self):
""" return reduce(operator.mul, self.physical_mesh_id.shape, 1)
Return the number of devices contained in the device mesh.
"""
return reduce(operator.mul, self._physical_mesh_id.shape, 1)
@property @property
def logical_mesh_id(self) -> torch.Tensor: def logical_mesh_id(self):
"""
Return the logical mesh id.
"""
return self._logical_mesh_id return self._logical_mesh_id
def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup: def __deepcopy__(self, memo):
"""
Return the process group on the specified axis.
Args:
axis (int): the axis of the process group.
global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None)
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
return self._process_group_dict[global_rank][axis]
def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
"""
Return the process groups for all axes.
Args:
global_rank (int, optional): the global rank of the process
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
return self._process_group_dict[global_rank]
def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
"""
Return the ranks in the process group on the specified axis.
Args:
axis (int): the axis of the process group.
global_rank (int, optional): the global rank of the process
"""
if global_rank is None:
global_rank = self._global_rank_of_current_process
return self._ranks_in_the_process_group[global_rank][axis]
def __deepcopy__(self, memo) -> "DeviceMesh":
cls = self.__class__ cls = self.__class__
result = cls.__new__(cls) result = cls.__new__(cls)
memo[id(self)] = result memo[id(self)] = result
...@@ -188,206 +86,111 @@ class DeviceMesh: ...@@ -188,206 +86,111 @@ class DeviceMesh:
if k != 'process_groups_dict': if k != 'process_groups_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo)) setattr(result, k, __import__("copy").deepcopy(v, memo))
else: else:
# process group cannot be copied
# thus, we share them directly
setattr(result, k, v) setattr(result, k, v)
return result return result
def _init_global_to_logical_rank_mapping(self, def flatten(self):
mapping: Dict,
tensor: torch.Tensor,
index_list: List[int] = []) -> Dict[int, List[int]]:
""" """
Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. Flatten the logical mesh into an effective 1d logical mesh,
Args:
mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
tensor (torch.Tensor): the tensor that contains the logical mesh ids.
index_list (List[int])
Returns:
mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
The value is a list of integers and each integer represents the local rank in the indexed axis.
""" """
for index, inner_tensor in enumerate(tensor): flatten_mesh_shape_size = len(self.mesh_shape)
# index means the local rank in the current axis flatten_mesh_shape = [self.num_devices]
# inner_tensor refers to the processes with the same local rank return DeviceMesh(self.physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self.init_process_group,
need_flatten=False)
def _global_rank_to_logical_rank_map(self, tensor, index_list):
'''
This method is a helper function to build convert_map recursively.
'''
for index, inner_tensor in enumerate(tensor):
if inner_tensor.numel() == 1: if inner_tensor.numel() == 1:
# if the inner_tensor only has one element, it means that self.convert_map[int(inner_tensor)] = index_list + [index]
# it already reaches the last axis
# we append its local_rank in the last axis to the index_list
# and assign to the mapping
# the value of the mapping is the the local rank at the indexed axis of the device mesh
mapping[int(inner_tensor)] = index_list + [index]
else: else:
# we recursively go into the function until we reach the last axis self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
# meanwhile, we should add the local rank in the current axis in the index_list
self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
def init_logical_process_group(self): def create_process_groups_for_logical_mesh(self):
''' '''
This method is used to initialize the logical process groups which will be used in communications This method is used to initialize the logical process groups which will be used in communications
among logical device mesh. among logical device mesh.
Note: if init_process_group set to False, you have to call this method manually. Otherwise, Note: if init_process_group set to False, you have to call this method manually. Otherwise,
the communication related function, such as ShapeConsistencyManager.apply will raise errors. the communication related function, such as ShapeConsistencyManager.apply will raise errors.
''' '''
# sanity check process_groups_dict = {}
assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" check_duplicate_list = []
assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
# update the global rank of the current process
self._global_rank_of_current_process = dist.get_rank()
duplicate_check_list = []
# flatten the global ranks to 1D list
global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
for global_rank in global_rank_flatten_list:
# find the other ranks which are in the same process group as global_rank
ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
# skip duplicated process group creation
if ranks_in_same_group in duplicate_check_list:
continue
# create the process group
pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)
# keep this process group in the process_groups_dict
for rank in ranks_in_same_group:
if rank not in self._process_group_dict:
self._process_group_dict[rank] = dict()
self._process_group_dict[rank][axis] = pg_handler
# update the init flag
# we only allow init for once
self._is_initialized = True
def _init_ranks_in_the_same_group(self):
"""
This method is used to initialize the ranks_in_the_same_group dictionary.
"""
# flatten the global ranks to 1D list
global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
for global_rank in global_rank_flatten_list: for global_rank in global_rank_flatten_list:
# find the other ranks which are in the same process group as global_rank process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) for axis, process_group in process_groups.items():
if axis not in process_groups_dict:
for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): process_groups_dict[axis] = []
# create dict for each rank if process_group not in check_duplicate_list:
if global_rank not in self._process_group_dict: check_duplicate_list.append(process_group)
self._ranks_in_the_process_group[global_rank] = dict() process_group_handler = dist.new_group(process_group)
process_groups_dict[axis].append((process_group, process_group_handler))
# keep this process group in the process_groups_dict return process_groups_dict
self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group
def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]: def global_rank_to_logical_rank(self, rank):
""" return self.convert_map[rank]
Return the local rank of the given global rank in the logical device mesh.
Args: def global_rank_to_process_groups_with_logical_rank(self, rank):
rank (int): the global rank in the logical device mesh.
axis (int): the axis of the logical device mesh.
"""
local_ranks = self._global_to_local_rank_mapping[rank]
if axis:
return local_ranks[axis]
else:
return local_ranks
def _collate_global_ranks_in_same_process_group(self, global_rank):
''' '''
Give a global rank and return all global ranks involved in its associated process group in each axis. Give a global rank and return all logical process groups of this rank.
for example:
Example: physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
```python # [[0, 1, 2, 3],
sphysical_mesh_id = torch.arange(0, 16) # [4, 5, 6, 7],
mesh_shape = (4, 4) # [8, 9, 10,11],
# [12,13,14,15]]
# logical mesh will look like device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
# [[0, 1, 2, 3], print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))
# [4, 5, 6, 7], output:
# [8, 9, 10,11], # key is axis name
# [12,13,14,15]] # value is a list of logical ranks in same axis with rank 0
{0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
print(device_mesh.collate_global_ranks_in_same_process_group(0))
# key is axis name
# value is a list of global ranks in same axis with rank 0
# output will look like
# {
0: [0, 4, 8, 12],
1: [0, 1, 2, 3]
# }
''' '''
# We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping process_groups = {}
# for self._global_to_local_rank_mapping for d in range(self.logical_mesh_id.dim()):
# the key is the global rank for replacer in range(self.logical_mesh_id.shape[d]):
# the value is the list of local ranks corresponding to the global rank with respect of different axes if d not in process_groups:
# we can see the list of local ranks as the process coordinates for simplicity process_groups[d] = []
# the key and value are all unique, therefore, process_group_member = self.convert_map[rank].copy()
# we can also to use the coordinates to find the global rank process_group_member[d] = replacer
process_groups[d].append(process_group_member)
# ========================================================================= return process_groups
# Step 1
# find all the process_coordinates for processes in the same process group def global_rank_to_process_groups_with_global_rank(self, rank):
# as the given global rank '''
# ========================================================================= Give a global rank and return all process groups of this rank.
for example:
# each physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
processes_in_the_same_process_group = {} mesh_shape = (4, 4)
# [[0, 1, 2, 3],
for dim in range(self.logical_mesh_id.dim()): # [4, 5, 6, 7],
# iterate over the dimension size so that we can include all processes # [8, 9, 10,11],
# in the same process group in the given axis # [12,13,14,15]]
# the _local_rank refers to the local rank of the current process device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
for _local_rank in range(self.logical_mesh_id.shape[dim]): print(device_mesh.global_rank_to_process_groups_with_global_rank(0))
output:
# if this dimension is not initailized yet, # key is axis name
# initialize it with an empty array # value is a list of global ranks in same axis with rank 0
if dim not in processes_in_the_same_process_group: {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
processes_in_the_same_process_group[dim] = [] '''
logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank)
# get the local rank corresponding to the global rank process_groups = {}
process_coordinates = self._global_to_local_rank_mapping[global_rank].copy() for dim, logical_ranks in logical_process_groups.items():
process_groups[dim] = []
# replace the local rank in the given dimension with the for logical_rank in logical_ranks:
# lcoal rank of the current process iterated for g_rank, l_rank in self.convert_map.items():
process_coordinates[dim] = _local_rank if l_rank == logical_rank:
processes_in_the_same_process_group[dim].append(process_coordinates) process_groups[dim].append(g_rank)
return process_groups
# =================================================================
# Step 2
# Use local rank combination to find its corresponding global rank
# =================================================================
# the key of the dict is the axis
# the value is the list of global ranks which are in the same process group as the given global rank
global_pg_ranks = {}
for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
global_pg_ranks[dim] = []
for process_coordinates in coordinates_of_all_processes:
# find the global rank by local rank combination
for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
if process_coordinates == _process_coordinates:
global_pg_ranks[dim].append(_global_rank)
return global_pg_ranks
def flatten(self):
"""
Flatten the logical mesh into an effective 1d logical mesh,
"""
flatten_mesh_shape_size = len(self.mesh_shape)
flatten_mesh_shape = [self.num_devices]
return DeviceMesh(self._physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self._init_process_group)
def all_gather_cost(self, num_bytes, mesh_dim): def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim] num_devices = self.logical_mesh_id.shape[mesh_dim]
...@@ -409,3 +212,38 @@ class DeviceMesh: ...@@ -409,3 +212,38 @@ class DeviceMesh:
penalty_factor = num_devices / 2.0 penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
class FlattenDeviceMesh(DeviceMesh):
def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
super().__init__(physical_mesh_id,
mesh_shape,
mesh_alpha,
mesh_beta,
init_process_group=False,
need_flatten=False)
# Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
self.mesh_alpha = max(self.mesh_alpha)
self.mesh_beta = min(self.mesh_beta)
# Different from original process_groups_dict, rank_list is not stored
self.process_number_dict = self.create_process_numbers_for_logical_mesh()
def create_process_numbers_for_logical_mesh(self):
'''
Build 1d DeviceMesh in column-major(0) and row-major(1)
for example:
mesh_shape = (2,4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
num_devices = reduce(operator.mul, self.mesh_shape, 1)
process_numbers_dict = {}
process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
return process_numbers_dict
def mix_gather_cost(self, num_bytes):
num_devices = reduce(operator.mul, self.mesh_shape, 1)
return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
from types import MethodType from types import MethodType
from typing import Callable, Dict, Optional, Union from typing import Callable, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -8,9 +8,8 @@ from torch import Tensor ...@@ -8,9 +8,8 @@ from torch import Tensor
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor from colossalai._analyzer._subclasses import MetaTensor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import DTensor from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.tensor.d_tensor.layout import Layout
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [ _NORMAL_FACTORY = [
...@@ -173,7 +172,7 @@ class LazyTensor(torch.Tensor): ...@@ -173,7 +172,7 @@ class LazyTensor(torch.Tensor):
self.clean() self.clean()
return _convert_cls(self, target) return _convert_cls(self, target)
def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: def distribute(self, layout: Layout) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
Args: Args:
...@@ -184,7 +183,7 @@ class LazyTensor(torch.Tensor): ...@@ -184,7 +183,7 @@ class LazyTensor(torch.Tensor):
""" """
target = self._materialize_data() target = self._materialize_data()
self.clean() self.clean()
local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor local_tensor = DTensor(target, layout).local_tensor
return _convert_cls(self, local_tensor) return _convert_cls(self, local_tensor)
def clean(self) -> None: def clean(self) -> None:
...@@ -537,10 +536,7 @@ class LazyInitContext: ...@@ -537,10 +536,7 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose) return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod @staticmethod
def distribute(module: nn.Module, def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
device_mesh: DeviceMesh,
sharding_spec_dict: Dict[str, ShardingSpec],
verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args: Args:
...@@ -550,7 +546,7 @@ class LazyInitContext: ...@@ -550,7 +546,7 @@ class LazyInitContext:
""" """
def apply_fn(name: str, p: LazyTensor): def apply_fn(name: str, p: LazyTensor):
p.distribute(device_mesh, sharding_spec_dict[name]) p.distribute(layout_dict[name])
return _apply_to_lazy_module(module, apply_fn, verbose) return _apply_to_lazy_module(module, apply_fn, verbose)
......
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
try: try:
......
# ⚡️ ShardFormer
## 📚 Table of Contents
- [⚡️ ShardFormer](#️-shardformer)
- [📚 Table of Contents](#-table-of-contents)
- [🔗 Introduction](#-introduction)
- [🔨 Usage](#-usage)
- [🔮 Simple example](#-simple-example)
- [💡 Policy](#-policy)
- [😊 Module](#-module)
## 🔗 Introduction
**Shardformer** is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background.
## 🔨 Usage
The sample API usage is given below:
``` python
from colossalai.shardformer import shard_model
from transformers import BertForMaskedLM
# create huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# make the huggingface model paralleled to ShardModel
# auto policy:
sharded_model = shard_model(model)
# custom policy:
from xxx import <POLICYCLASS>
sharded_model = shard_model(model, <POLICYCLASS>)
# do angthing as normal
...
```
## 🔮 Simple example
``` shell
# inference
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference
# train
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train
```
## 💡 Policy
If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model.
You should do:
1. Inherit Policy class
2. Overwrite `argument_policy` method
- In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. Shardformer will replace all the layer belonging to the class you specified.
- `attr_dict` is dict contains all the attributes need to be modified in this layer.
- `param_funcs` is a list contains some functions which will return the path of the weight and bias from the layer.
3. Overwrite `inject_policy` method (Optional)
- Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method.
4. Overwrite or add the param functions
- These functions use a suffix to record the path of weight or bias for the layer.
- The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively.
5. Overwrite `binding_policy` (Optional)
- Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers.
- This function will return a dict, the key and value are the suffix of weight need to be binded.
More details can be found in shardformer/policies/basepolicy.py
``` python
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument
CustomPolicy(Policy):
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
r"""
Return the dict for the modify policy, the key is the original layer class and the value is the
argument for the modify layer
Args:
model_config (:class:`tansformer.Config`): The config of transformer model
shard_config (:class:`ShardConfig`): The config for sharding model
Return:
Dict for the modify policy,
::
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
origin layer class2 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
...
}
"""
raise NotImplementedError
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
r"""
Return the dict for the inject model
Return:
The injected model, key is the original model and value is the new shardmodel
::
(OrignModel, CustomModel)
in `CustomModel`, we can overwrite the forward and backward process
"""
return ()
@staticmethod
def binding_policy() -> Dict:
r"""
Return the dict for the binding model
Return:
This method should return the binding relationship for some layers share the weight or bias,
the key and value is the suffix of the weight or bias of the model
::
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return NotImplementedError
@staticmethod
def attn_in() -> List:
"""
Attention qkv layer
Returns:
List[Layer]: List of layer object, each layer is the new
"""
return NotImplementedError
@staticmethod
def attn_out() -> List:
"""
Attention output projection layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def mlp_in() -> List:
"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def mlp_out() -> List:
"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def embedding() -> List:
"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def unembedding() -> List:
"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
```
## 😊 Module
1. Flowchart
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/shardformer_flowchart.png" width="600" />
</p>
2. Important Modules
- CLASS `shard_model`:
This is the user api to use shardformer, just create a model from transformers and define a custom policy or use shardformer autopolicy to make a shard model.
- CLASS `Layer`:
Parameters:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
- replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
- ignore (bool): Whether to ignore this layer if it is not in the model
This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.
CLASS `Col_Layer(Layer)`:
- gather_output (bool): Whether to gather the output of the layer
This class inherited from `Layer`, representing the layer will be sliced along column.
CLASS `Row_Layer(Layer)`:
This class inherited from `Layer`, representing the layer will be sliced along row.
- CLASS `Policy`:
In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class.
- `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`......
These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions.
- `Policy.argument_policy()`
In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach.
- `Policy.inject_policy()`
This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else.
- `Policy.binding_policy()`
This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters.
- CLASS `ModelSharder(model, policy)`:
This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model.
- `ModelShard.inject_model()`
This function is used to inject the model to modify the forward and backward progress.
- `ModelShard.replace_layer()`
This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication.
- `ModelShard.bind_layer()`
This function is used to help different layers share weight or bias.
- CLASS `Slicer`:
This class is used to slice tensor according to policy.
3. DistCrossEntropy Loss
- Overview
In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The origin loss function is:
$$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$
alse can be represented as:
$$ loss = \log(\sum_i\exp(x[i])) - x[class]$$
- Step
- First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large
- Get a mask to mask the logits not in the local device
- Caculate the loss according to the second formula
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm
Args:
input: input matrix.
weight: weight matrix.
bias: bias matrix.
normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability
"""
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
class DistCrossEntropy(Function):
r"""
Overwrite the forward and backward function to calculate the cross entropy loss before gather
Args:
Function (:class:`torch.autograd.Function`): default
"""
@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i]))
and can be rewrite as:
loss = log(sum(exp(x[i])) - x[class]
To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i]
Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
[batch_size, seq_len]
Returns:
:class:`torch.Tensor`: The cross entropy loss
"""
# get the max
logits_max = torch.max(vocab_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)
# minus the max to avoid the result of sum of exp is too large and the log is nan
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask the target in the local device
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank()
world_size = dist.get_world_size()
global_vocab_size = partition_vocab_size * world_size
# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
down_shreshold = rank * delta
up_shreshold = down_shreshold + delta
mask = (target < down_shreshold) | (target >= up_shreshold)
masked_target = target.clone() - down_shreshold
masked_target[mask] = 0
# reshape the logist and target
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
# reshape the labels to [bath_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero
pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
masked_target_1d]
pred_logits_1d = pred_logits_1d.clone().contiguous()
pred_logits = pred_logits_1d.view_as(target)
pred_logits[mask] = 0.0
# allreduce the get all x(i,y)
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)
# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
loss = torch.log(sum_exp_logits) - pred_logits
loss = torch.sum(loss).div_(loss.numel())
# caculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# retrieve the saved tensors
exp_logits, mask, masked_target_1d = ctx.saved_tensors
# use exp logits as the input grad
grad_logits = exp_logits
partion_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
update = 1.0 - mask.view(-1).float()
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None
def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels)
import os
import time
from contextlib import contextmanager
import torch
import torch.nn as nn
class SeedManager:
"""
This class is a random state manager to change random state for different random seed.
"""
def __init__(self):
original_state = torch.cuda.get_rng_state()
seed = int(f"{int(time.time())}{os.environ['RANK']}")
torch.cuda.manual_seed(int(seed))
self.dropout_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(original_state)
def set_mode(self, rng_state):
torch.cuda.set_rng_state(rng_state)
def get_current_mode(self):
current_state = torch.cuda.get_rng_state()
return current_state
@contextmanager
def dropout_mode(self):
"""
This is a context manager to change the dropout state and recover the original state.
Usage:
::
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
"""
try:
current_mode = self.get_current_mode()
yield self.set_mode(self.dropout_state)
finally:
self.dropout_state = self.get_current_mode()
self.set_mode(current_mode)
_seed_manager = SeedManager()
class Dropout1D(nn.Dropout):
def __init__(self, p=0.5, inplace=False):
super().__init__(p, inplace)
def forward(self, input):
with _seed_manager.dropout_mode():
input = super().forward(input)
return input
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from collections import OrderedDict
from typing import Callable, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule
from colossalai.nn.layer.parallel_1d._utils import (
gather_forward_split_backward,
get_parallel_input,
reduce_grad,
reduce_input,
set_parallel_input,
split_forward_gather_backward,
)
from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition
from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding
from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
partition_tensor_parallel_state_dict,
)
from colossalai.utils.cuda import get_current_device
from ._operation import linear_with_async_comm
Fast_LN = None
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
Fast_LN = FastLayerNorm
except ImportError:
pass
# @LAYERS.register_module
class Linear1D(ColossalaiModule):
r"""Linear layer for 1D parallelism.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
gather_output (bool, optional): Whether to call all-gather on output, defaults to False.
skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
parallel_input = get_parallel_input()
if not parallel_input and not gather_output:
layer = Linear1D_Col(in_features,
out_features,
bias=bias,
dtype=dtype,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
else:
layer = Linear1D_Row(in_features,
out_features,
bias=bias,
dtype=dtype,
parallel_input=parallel_input,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
super().__init__(layer)
# @LAYERS.register_module
class LayerNorm1D(ColossalaiModule):
r"""
Layer Normalization for colossalai
Args:
normalized_shape (int): input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
bias (bool, optional): Whether to add a bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
else:
norm = None
try:
from apex.normalization import FusedLayerNorm
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
except ImportError:
norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
super()._save_to_state_dict(destination, prefix, keep_vars)
# @LAYERS.register_module
class Classifier1D(ParallelLayer):
r"""RowLinear with given weight. Classifier of 1D parallelism.
Args:
in_features (int): size of each input sample.
num_classes (int): number of classes.
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
self.has_weight = True
if bias:
self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
env.vocab_parallel = False
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self):
if self.has_weight:
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
})
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if self.bias is not None:
output = output + self.bias
return output
# @LAYERS.register_module
class VocabParallelClassifier1D(ParallelLayer):
r"""ColLinear with given weight. Classifier of 1D parallelism.
Args:
in_features (int): size of each input sample.
num_classes (int): number of classes.
weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.gather_output = gather_output
self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension.
self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs))
self.has_weight = True
if bias:
self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
env.vocab_parallel = True
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
})
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
return output
# @LAYERS.register_module
class Linear1D_Col(ParallelLayer):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
self.out_features_per_partition = out_features
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
is_parallel_output = not self.gather_output
set_parallel_input(is_parallel_output)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
})
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
# output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
if self.skip_bias_add:
return output, self.bias
else:
return output
# @LAYERS.register_module
class Linear1D_Row(ParallelLayer):
r""" Linear layer with row parallelism
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False.
skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (:class:`typing.Callable`, optional):
The initializer of bias, defaults to xavier uniform initializer.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
super().__init__()
self.stream_chunk_num = stream_chunk_num
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
self.input_size_per_partition = in_features
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def chunk_weight(self):
self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
# bias
if self.bias is not None:
bias = state_dict.pop(bias_key, None)
if bias is not None:
local_state[bias_key] = bias
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
})
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
if self.stream_chunk_num > 1:
if self.training:
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
group=gpc.get_group(ParallelMode.PARALLEL_1D),
async_op=True)
handle_list.append(handle)
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
# @LAYERS.register_module
class Embedding1D(ParallelLayer):
r"""Embedding for 1D parallelism.
Args:
num_embeddings (int): number of embeddings.
embedding_dim (int): dimension of embedding.
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
therefore, the embedding vector at padding_idx is not updated during training,
i.e. it remains as a fixed “pad”, defaults to None.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
he initializer of weight, defaults to normal initializer.
The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
of frequency of the words in the mini-batch. Default False.
sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
More details about ``args`` and ``kwargs`` could be found in
`Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
More details about ``initializer`` please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: -1},
partition_states={weight_key: True})
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: -1},
partition_states={weight_key: True},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
return output
# @LAYERS.register_module
class VocabParallelEmbedding1D(ParallelLayer):
r"""Embedding parallelized in the vocabulary dimension.
Args:
num_embeddings (int): number of embeddings.
embedding_dim (int): dimension of embedding.
padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
therefore, the embedding vector at padding_idx is not updated during training,
i.e. it remains as a fixed “pad”, defaults to None.
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
weight_initializer (:class:`typing.Callable`, optional):
he initializer of weight, defaults to normal initializer.
The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
of frequency of the words in the mini-batch. Default False.
sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
More details about ``args`` and ``kwargs`` could be found in
`Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.
More details about initializer please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
# self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings_per_partition = num_embeddings
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
if weight is not None:
local_state[weight_key] = weight
local_state = partition_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True})
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True},
keep_vars=keep_vars)
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
return output
# @LAYERS.register_module
class Dropout1D(ParallelLayer):
"""Dropout layer of 1D parallelism.
Args:
p (float, optional): probability of an element to be zeroed, defaults 0.5.
inplace (bool, optional): whether to do dropout in-place, default to be False.
"""
def __init__(self, p: float = 0.5, inplace: bool = False):
super().__init__()
self.parallel_input = get_parallel_input()
self.p = p
self.inplace = inplace
def forward(self, input_: Tensor) -> Tensor:
if self.parallel_input:
with seed(ParallelMode.TENSOR):
output = F.dropout(input_, self.p, self.training, self.inplace)
else:
output = F.dropout(input_, self.p, self.training, self.inplace)
return output
# @LAYERS.register_module
class PatchEmbedding1D(ColossalaiModule):
"""
2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The initializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: torch.dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
embed = VanillaPatchEmbedding(img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer)
super().__init__(embed)
def _load_from_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed']
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
for key in param_keys:
param = state_dict.pop(key, None)
if param is not None:
local_state[key] = param
local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D)
super()._load_from_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
super()._save_to_state_dict(destination, prefix, keep_vars)
from typing import Any, Dict, List, Type
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput
from ..layer.dist_crossentropy import applyDistCrossEntropy
class BertForMaskedLM_(BertForMaskedLM):
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
# print("[Inject OK] Injected forward method")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels)
# if labels is not None:
# loss_fct = CrossEntropyLoss() # -100 index = padding token
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
import torch.nn as nn
def build_policies():
r"""
Build the policies for the model
Return:
The dict for the policies
"""
auto_policy_dict = {}
from transformers import BertForMaskedLM
from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
from transformers import BertForSequenceClassification
from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
from transformers import GPT2Model
from .gpt2 import GPT2Policy
auto_policy_dict[GPT2Model] = GPT2Policy
from transformers import GPT2LMHeadModel
from .gpt2 import GPT2LMHeadModelPolicy
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
return auto_policy_dict
def get_autopolicy(model: nn.Module):
r"""
Return the auto policy for the model
Args:
model (:class:`nn.Module`): The model to get the auto policy
Return:
:class:`Policy`: The auto policy for the model
"""
auto_policy_dict = build_policies()
policy = auto_policy_dict.get(model.__class__, None)
if policy is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
)
return policy
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)
# part of code modified from https://github.com/tunib-ai/parallelformers
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
@dataclass
class Argument:
r"""
The argument class for the policy
Args:
attr_dict (Dict[str, Any]): The dict for the param setting
param_funcs (:class:`List[Callable]`): The list for the param functions
"""
attr_dict: Dict[str, Any]
param_funcs: List[Callable]
@dataclass
class Layer:
r"""
The layer object for the policy
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
but in GPT2 `Conv1D` layer is [in, out] which is reversed.
n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices,
but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
each device should have a part of Q, K and V weight.
"""
weight: str = None
bias: str = None
replace_layer: Any = None
ignore: bool = False
reversed: bool = False
n_cast: int = None
@dataclass
class Col_Layer(Layer):
r"""
Class for col shard layer in MegatronLM
Args:
gather_output (bool): Whether to gather the output of the layer
"""
gather_output: bool = False
@dataclass
class Row_Layer(Layer):
r"""
Class for col shard layer in MegatronLM
"""
pass
class Policy():
r"""
The base class for all the policies
For each different model, it should have a different policy class, like BertPolicy for Bert Model
or OPTPolicy for OPT model.
AutoPolicy:
Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
BertForSequenceClassification, etc., for each different Bert model we difine different policy class
and overwrite the method like ``inject_policy`` to modify the forward and backward process.
CustomPolicy:
If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``
class for the example.
"""
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
r"""
Return the dict for the modify policy, the key is the original layer class and the value is the
argument for the modify layer
Args:
model_config (:class:`tansformer.Config`): The config of transformer model
shard_config (:class:`ShardConfig`): The config for sharding model
Return:
Dict for the modify policy,
::
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
origin layer class2 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
...
}
"""
raise NotImplementedError
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
r"""
Return the dict for the inject model
Return:
The injected model, key is the original model and value is the new shardmodel
::
(OrignModel, CustomModel)
in `CustomModel`, we can overwrite the forward and backward process
"""
return None
@staticmethod
def binding_policy() -> Dict:
r"""
Return the dict for the binding model
Return:
This method should return the binding relationship for some layers share the weight or bias,
the key and value is the suffix of the weight or bias of the model
::
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return None
@staticmethod
def attn_in() -> List:
r"""
Attention qkv layer
In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
in ``Layer`` object can refer to the ``Layer`` class.
Returns:
List[Layer]: List of layer object, each layer is the new
"""
return NotImplementedError
@staticmethod
def attn_out() -> List:
r"""
Attention output projection layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def mlp_in() -> List:
r"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def mlp_out() -> List:
r"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def embedding() -> List:
r"""
Partially slice the embedding layer
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
@staticmethod
def unembedding() -> List:
r"""
Partially slice the embedding layer
Return:
List[Layer]: List of layer object
"""
return None
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
class BertPolicy(Policy):
@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
return {
BertLayer:
Argument(
attr_dict={
# 1. shard hidden size
"attention.self.all_head_size": config.hidden_size // world_size,
"crossattention.self.all_head_size": config.hidden_size // world_size,
# 2. shard number of heads
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
},
param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
BertEmbeddings:
Argument(
attr_dict={
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
"word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
},
param_funcs=[
BertPolicy.embedding,
]),
BertLMPredictionHead:
Argument(
attr_dict={
# 1. shard vocab size
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
# 2. add the size of the sliced embedding layer excluding the last slice
},
param_funcs=[
BertPolicy.unembedding,
])
}
@staticmethod
def binding_policy() -> Dict:
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
@staticmethod
def attn_in() -> List:
return [
Col_Layer(
weight="attention.self.query.weight",
bias="attention.self.query.bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
weight="attention.self.key.weight",
bias="attention.self.key.bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
weight="attention.self.value.weight",
bias="attention.self.value.bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
weight="crossattention.self.query.weight",
bias="crossattention.self.query.bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
Col_Layer(
weight="crossattention.self.key.weight",
bias="crossattention.self.key.bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
Col_Layer(
weight="crossattention.self.value.weight",
bias="crossattention.self.value.bias",
replace_layer=col_nn.Linear1D_Col,
ignore=True,
),
]
@staticmethod
def attn_out() -> List:
return [
Row_Layer(
weight="attention.output.dense.weight",
bias="attention.output.dense.bias",
replace_layer=col_nn.Linear1D_Row,
),
Row_Layer(
weight="crossattention.output.dense.weight",
bias="crossattention.output.dense.bias",
replace_layer=col_nn.Linear1D_Row,
ignore=True,
),
]
@staticmethod
def mlp_in() -> List:
return [
Col_Layer(
weight="intermediate.dense.weight",
bias="intermediate.dense.bias",
replace_layer=col_nn.Linear1D_Col,
),
]
@staticmethod
def mlp_out() -> List:
return [
Row_Layer(
weight="output.dense.weight",
bias="output.dense.bias",
replace_layer=col_nn.Linear1D_Row,
),
]
@staticmethod
def embedding() -> List:
return [Col_Layer(
weight="word_embeddings.weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)]
@staticmethod
def unembedding() -> List:
return [
Col_Layer(
weight="decoder.weight",
bias="decoder.bias",
replace_layer=col_nn.Linear1D_Col,
# gather_output=True,
)
]
from transformers import BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
class BertForMaskedLMPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
return (BertForMaskedLM, BertForMaskedLM_)
class BertForSequenceClassificationPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Dict:
return {}
# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# _ = BertForMaskedLMPolicy(model)
# print(isinstance(model,list(_.inject_policy().keys())[0]))
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
class GPT2Policy(Policy):
@staticmethod
def argument_policy(config, world_size):
return {
GPT2Model:
Argument(attr_dict={}, param_funcs=[
GPT2Policy.embedding,
]),
GPT2Block:
Argument(
attr_dict={
# 1. reduce hidden size
"attn.embed_dim": config.hidden_size // world_size,
"attn.split_size": config.hidden_size // world_size,
"crossattention.embed_dim": config.hidden_size // world_size,
"crossattention.split_size": config.hidden_size // world_size,
# 2. reduce number of heads
"attn.num_heads": config.num_attention_heads // world_size,
"crossattention.num_heads": config.num_attention_heads // world_size,
},
param_funcs=[
GPT2Policy.attn_in,
GPT2Policy.attn_out,
GPT2Policy.mlp_in,
GPT2Policy.mlp_out,
]),
}
@staticmethod
def attn_in() -> List:
return [
Col_Layer(weight="attn.c_attn.weight",
bias="attn.c_attn.bias",
n_cast=3,
reversed=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.c_attn.weight",
bias="crossattention.c_attn.bias",
n_cast=2,
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.q_attn.weight",
bias="crossattention.q_attn.bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col)
]
@staticmethod
def attn_out() -> List:
return [
Row_Layer(weight="attn.c_proj.weight",
bias="attn.c_proj.bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row),
Row_Layer(weight="crossattention.c_proj.weight",
bias="crossattention.c_proj.bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Row)
]
@staticmethod
def mlp_in() -> List:
return [
Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
]
@staticmethod
def mlp_out() -> List:
return [
Row_Layer(weight="mlp.c_proj.weight",
bias="mlp.c_proj.bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row)
]
@staticmethod
def embedding() -> List:
return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
from transformers import GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
@staticmethod
def argument_policy(config, world_size):
base_argument = GPT2Policy.argument_policy(config, world_size)
argument = {
GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
GPT2LMHeadModelPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument
@staticmethod
def unembedding() -> List:
return [
Col_Layer(weight="lm_head.weight",
bias="lm_head.bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True)
]
from .shard_config import ShardConfig
from .sharder import ModelSharder, shard_model
from .slicer import Slicer
__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
from dataclasses import dataclass
__all__ = ['ShardConfig']
@dataclass
class ShardConfig:
"""
The config for sharding the huggingface model for test
"""
rank: int
fp16: bool = True
num_gpus: int = 2
world_size: int = 2
backend = "nccl"
verbose: str = 'simple'
seed: int = None
require_grad: bool = False
master_addr: str = "127.0.0.1"
master_port: int = 29500
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment