Unverified Commit b5f9e37c authored by Hongxin Liu, committed by GitHub

[legacy] clean up legacy code (#4743)

* [legacy] remove outdated pipeline code (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
parent 32e7f994
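Every hunk below applies the same mechanical migration: the touched modules keep their names but move under the colossalai.legacy package. A representative before/after, taken from the hunks themselves:

# before this commit
from colossalai.core import global_context as gpc
from colossalai.tensor import ProcessGroup
# after this commit
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.tensor import ProcessGroup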
......@@ -4,8 +4,8 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context as gpc
from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.legacy.registry import LOSSES
......
import torch
from torch import nn
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.legacy.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
......
......@@ -5,7 +5,7 @@ from typing import Iterable, Optional, Set
import torch
import torch.distributed as dist
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.legacy.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import is_ddp_ignored
from .reducer import Reducer
......@@ -34,8 +34,8 @@ class ColoDDP(torch.nn.Module):
"""Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.
Example:
>>> from colossalai.core import global_context as gpc
>>> from colossalai.context import ParallelMode
>>> from colossalai.legacy.core import global_context as gpc
>>> from colossalai.legacy.context import ParallelMode
>>> model = torch.nn.Linear(20, 1)
>>> pg = ProcessGroup(tp_degree = world_size//2)
>>> model = ColoDDP(model, pg)
......
......@@ -4,7 +4,8 @@ import torch
import torch.nn.functional as F
from colossalai.legacy.nn._ops._utils import dual_all_to_all
from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
from colossalai.tensor import ColoParameter, ColoTensor
from .cache_mgr import CachedParamMgr, EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
......
......@@ -6,7 +6,7 @@ import torch.distributed as dist
import torch.nn.functional as F
from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
from colossalai.tensor import ProcessGroup
from colossalai.legacy.tensor import ProcessGroup
from .cache_mgr import EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
......
......@@ -7,7 +7,7 @@ import torch.nn as nn
from torch.profiler import record_function
from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
from colossalai.tensor import ProcessGroup
from colossalai.legacy.tensor import ProcessGroup
from .cache_mgr import EvictionStrategy
from .cached_embedding import CachedEmbeddingBag
......
from typing import Dict, List
from colossalai.tensor import ComputePattern
from colossalai.tensor.distspec import _DistSpec
from colossalai.legacy.tensor import ComputePattern
from colossalai.legacy.tensor.distspec import _DistSpec
class ColoModule(object):
......
from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from .colo_module import ColoModule
......
from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
from .colo_module import ColoModule
......
......@@ -2,7 +2,8 @@ from typing import Dict
import torch
from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec
from colossalai.legacy.tensor import ComputeSpec, ProcessGroup, distspec
from colossalai.tensor import ColoParameter
from . import ColoModule
......
from .layer_spec import LayerSpec
from .pipelinable import PipelinableContext, PipelinableModel
__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec']
import torch
from colossalai.utils.model.utils import call_to_str
class LayerSpec:
"""
"""
def __init__(self, typename, *module_args, **module_kwargs):
......@@ -52,4 +54,4 @@ class LayerSpec:
return self._param_count
def reset_param_count(self):
self._param_count = 0
\ No newline at end of file
self._param_count = 0
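A LayerSpec defers construction: it records the class and constructor arguments instead of instantiating the module immediately. A minimal sketch using only the API visible in this diff (the Linear layer, its arguments, and the import path are illustrative assumptions):

import torch

from colossalai.legacy.pipeline import LayerSpec

# record a linear layer for deferred construction by the pipeline context
spec = LayerSpec(torch.nn.Linear, 1024, 1024, bias=False)
print(spec.typename)      # <class 'torch.nn.modules.linear.Linear'>
spec.reset_param_count()  # parameter bookkeeping used by the 'balanced' policy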
from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal
from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
\ No newline at end of file
__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal']
from .fx import get_topology as get_fx_topology
__all__ = ['get_fx_topology']
\ No newline at end of file
__all__ = ['get_fx_topology']
from torch.fx.graph_module import GraphModule
from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
import torch
from torch.fx.graph_module import GraphModule
from colossalai.legacy.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo
def partition_name_to_id(partition_name, is_input=False, is_output=False):
if is_input:
......@@ -12,6 +14,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
partition_id = int(partition_name.split(prefix)[-1]) + 2
return partition_id
# There are two kinds of value definitions in an fx.graph:
# 1. non direct_use & non direct_def: the output reaches the next partition through a temporary intermediate (getitem) value.
# e.g. submod1 = call_module(...)
......@@ -20,6 +23,8 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
# 2. direct_use & direct_def, which means the output is used by next partition directly.
# e.g. submod1 = call_module(...)
# submod2 = call_module(submod1, ...)
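The detection in find_input_in_partition below leans on fx's node naming: values unpacked from a tuple appear as call_function nodes whose names start with 'getitem'. A minimal, self-contained sketch (the toy module is illustrative, not from this repo):

import torch
import torch.fx as fx

class Pipe(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submod_1 = torch.nn.Linear(4, 4)

    def forward(self, x):
        # torch.split returns a tuple, so consumers go through getitem (case 1)
        out = torch.split(x, 2, dim=-1)
        a = out[0]  # traced as call_function(operator.getitem), node name 'getitem'
        b = out[1]  # node name 'getitem_1'
        # a single tensor value is consumed directly, with no getitem (case 2)
        return self.submod_1(torch.cat([a, b], dim=-1))

gm = fx.symbolic_trace(Pipe())
for node in gm.graph.nodes:
    print(node.op, node.name)  # the 'getitem*' names are what direct_def checks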
def find_input_in_partition(node, partitions, input_partitions=None):
p_input_val = None
direct_def = not node.name.startswith('getitem')
......@@ -45,9 +50,10 @@ def find_input_in_partition(node, partitions, input_partitions=None):
partition_id = partition_name_to_id(partition.name)
p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset)
return p_input_val
return p_input_val
def find_output_in_partition(node, partitions, output_partitions=None):
p_output_val = PartitionOutputVal()
for user in node.users:
......@@ -70,7 +76,7 @@ def find_output_in_partition(node, partitions, output_partitions=None):
if arg == user:
p_output_val.add(partition_id=partition_id, offset=i)
break
# user is output
if output_partitions is not None:
output_node = output_partitions[0]
......@@ -84,10 +90,11 @@ def find_output_in_partition(node, partitions, output_partitions=None):
break
return p_output_val
def get_topology(gm: GraphModule):
topo = Topo()
topo_output_partition = Partition()
input_partitions = []
partitions = []
output_partitions = []
......@@ -109,7 +116,7 @@ def get_topology(gm: GraphModule):
topo_input_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=0, partition=topo_input_partition)
topo.set_input_partition_id(partition_id=0)
for i, partition in enumerate(partitions):
topo_mid_partition = Partition()
# set input for submodule
......@@ -131,15 +138,16 @@ def get_topology(gm: GraphModule):
for user in partition.users:
cur_node = user
p_output_val = find_output_in_partition(cur_node, partitions, output_partitions)
topo_mid_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=i+2, partition=topo_mid_partition)
topo_mid_partition.add_output_val(p_output_val)
topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition)
# set input for output_partition
for partition in output_partitions:
topo_output_partition = Partition()
torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val(
find_input_in_partition(n, partitions, input_partitions)))
torch.fx.graph.map_arg(
partition.args[0],
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition_id(partition_id=1)
return topo
\ No newline at end of file
return topo
from typing import Dict, List
from dataclasses import dataclass
from typing import Dict, List
# This file defines the data structures used by the pipeline middleware.
@dataclass
class ValPosition:
partition_id: int
offset: int
def __str__(self) -> str:
res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
return res
def __repr__(self) -> str:
return self.__str__()
class PartitionInputVal(object):
def __init__(self, partition_id, offset) -> None:
# records the (partition_id, offset) this input comes from
val_pos = ValPosition(partition_id, offset)
self._from_partition_and_offset: ValPosition = val_pos
def get(self):
return self._from_partition_and_offset
def __str__(self) -> str:
res = ''
res += f'<-({self._from_partition_and_offset})'
return res
def __repr__(self) -> str:
return self.__str__()
class PartitionOutputVal(object):
def __init__(self) -> None:
# records every (partition_id, offset) this output is sent to
self._to_partition_and_offset: List[ValPosition] = []
def add(self, partition_id, offset):
val_pos = ValPosition(partition_id, offset)
self._to_partition_and_offset.append(val_pos)
def get(self):
return self._to_partition_and_offset
def __str__(self) -> str:
res = ''
res += '->('
......@@ -51,27 +56,29 @@ class PartitionOutputVal(object):
res += f'{val_pos},'
res += ')'
return res
def __repr__(self) -> str:
return self.__str__()
class Partition(object):
def __init__(self) -> None:
self._input_vals: List[PartitionInputVal] = []
self._output_vals: List[PartitionOutputVal] = []
def add_input_val(self, input_val: PartitionInputVal):
self._input_vals.append(input_val)
def add_output_val(self, output_val: PartitionOutputVal):
self._output_vals.append(output_val)
def get_input_vals(self):
return self._input_vals
def get_output_vals(self):
return self._output_vals
# get the output offsets sent to dst_partition_id
def get_output_offsets(self, dst_partition_id):
res = []
......@@ -80,9 +87,9 @@ class Partition(object):
for val_pos in outputs:
if val_pos.partition_id == dst_partition_id:
res.append(offset)
return res
# get all partition_ids this partition receives input from
def get_input_partition_ids(self):
res = []
......@@ -91,7 +98,7 @@ class Partition(object):
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
# get all partition_ids this partition sends output to
def get_output_partition_ids(self):
res = []
......@@ -101,24 +108,25 @@ class Partition(object):
if val_pos.partition_id not in res:
res.append(val_pos.partition_id)
return res
def __str__(self) -> str:
res = ''
res += f' input:\n'
res += f' length:{len(self._input_vals)}\n'
for i, input_val in enumerate(self._input_vals):
res += f' offset={i}:{input_val}\n'
res += f' output:\n'
res += f' length:{len(self._output_vals)}\n'
for i, output_val in enumerate(self._output_vals):
res += f' offset={i}:{output_val}\n'
return res
def __repr__(self) -> str:
return self.__str__()
# This class is a middleware between the partition splitter
# and the pipeline scheduler. It records the graph info about
# partition inputs/outputs and provides it to the scheduler.
......@@ -132,42 +140,43 @@ class Partition(object):
# _input_partition_id: the key represents input_partition
# _output_partition_id: the key represents output_partition
class Topo(object):
def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
self._partitions: Dict[int, Partition] = {}
self._input_partition_id = input_partition_id
self._output_partition_id = output_partition_id
def set_input_partition_id(self, partition_id: int):
self._input_partition_id = partition_id
def set_output_partition_id(self, partition_id: int):
self._output_partition_id = partition_id
def get_input_partition_id(self):
return self._input_partition_id
def get_output_partition_id(self):
return self._output_partition_id
def set_partitions(self, partition_id: int, partition: Partition):
self._partitions[partition_id] = partition
def get_mid_partitions(self):
res = {} #{partition_id: Partition}
res = {} #{partition_id: Partition}
for partition_id, partition in self._partitions.items():
if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
continue
res[partition_id] = partition
return res
def get_mid_partition_ids(self):
return list(self.get_mid_partitions().keys())
def get_input_partition(self):
if self._input_partition_id is not None:
return self._partitions[self._input_partition_id]
return None
def get_output_partition(self):
if self._output_partition_id is not None:
return self._partitions[self._output_partition_id]
......@@ -175,7 +184,7 @@ class Topo(object):
def get_partition_by_id(self, partition_id):
return self._partitions[partition_id]
def __str__(self) -> str:
res = ''
if len(self._partitions) == 0:
......@@ -186,21 +195,20 @@ class Topo(object):
res += '{\n'
res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
res += '}\n'
mid_parts = self.get_mid_partitions()
for i, (partition_id, part) in enumerate(mid_parts.items()):
res += '{\n'
res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
res += '}\n'
output_part = self.get_output_partition()
if output_part is not None:
res += '{\n'
res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
res += '}\n'
return res
def __repr__(self) -> str:
return self.__str__()
\ No newline at end of file
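Putting the data structures together: a hand-assembled Topo (the three-partition layout is hypothetical) mirroring the id convention of get_topology above, where the input partition is always id 0, the output partition id 1, and submodules start at id 2:

from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo

topo = Topo()

# input partition (id 0): sends its only value to partition 2, offset 0
inp = Partition()
inp_out = PartitionOutputVal()
inp_out.add(partition_id=2, offset=0)
inp.add_output_val(inp_out)
topo.set_partitions(partition_id=0, partition=inp)
topo.set_input_partition_id(partition_id=0)

# mid partition (id 2): consumes from partition 0, sends to the output partition
mid = Partition()
mid.add_input_val(PartitionInputVal(partition_id=0, offset=0))
mid_out = PartitionOutputVal()
mid_out.add(partition_id=1, offset=0)
mid.add_output_val(mid_out)
topo.set_partitions(partition_id=2, partition=mid)

# output partition (id 1): receives from partition 2
outp = Partition()
outp.add_input_val(PartitionInputVal(partition_id=2, offset=0))
topo.set_partitions(partition_id=1, partition=outp)
topo.set_output_partition_id(partition_id=1)

print(topo.get_mid_partition_ids())                        # [2]
print(topo.get_partition_by_id(2).get_output_offsets(1))   # [0]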
import inspect
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from .layer_spec import LayerSpec
from .utils import (
build_kwargs_for_function,
build_kwargs_for_module,
call_module,
customized_partition,
exec_func_with_kwargs,
exec_funcs_with_kwargs,
partition_balanced,
partition_uniform,
......@@ -135,8 +131,10 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
children_name = []
for child in self._root_children:
layer_spec = self._layer_spec_dict[id(child)]
if layer_spec.typename in (torch.nn.modules.container.ModuleList,
torch.nn.modules.container.Sequential):
if layer_spec.typename in (
torch.nn.modules.container.ModuleList,
torch.nn.modules.container.Sequential,
):
for child_in_container in layer_spec.children:
self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
for name, module in self._model.named_modules():
......@@ -155,9 +153,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
named_modules = dict(self._model.named_modules())
for index, element in enumerate(exec_seq):
if isinstance(element, str):
if element == 'SPLIT_NODE':
if element == "SPLIT_NODE":
continue
assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'
assert (
element in named_modules
), f"Found invalid module name {element}, please check if you spell the module name correctly."
# get the layer spec based on the module ID
module = named_modules[element]
......@@ -198,11 +198,12 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
param_counts.append(layer_spec.count_params())
parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
elif self._policy == "customized":
assert self._exec_seq is not None, f'An explicit exec_seq must be defined by user in customized policy mode.'
assert (self._exec_seq
is not None), f"An explicit exec_seq must be defined by user in customized policy mode."
self.customized_parts = customized_partition(self._exec_seq)
assert len(self.customized_parts) == gpc.get_world_size(
ParallelMode.PIPELINE
), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}'
), f"World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}"
parts = self.customized_parts[rank]
else:
raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].")
......@@ -241,7 +242,6 @@ class PipelinableModel(torch.nn.Module):
def forward(self, *input_tensor, **kwargs):
for module in self._module_list:
if id(module) in self._front_func_dict:
input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs)
......
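For reference, the intended call sequence for this context, sketched from the methods visible in the diff (the toy model, the keyword-argument names of partition, and the two-stage setup are illustrative assumptions, not authoritative):

import torch

from colossalai.legacy.pipeline import PipelinableContext

# trace model construction so every child module is recorded as a LayerSpec
pipelinable = PipelinableContext()
with pipelinable:
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 4),
    )

pipelinable.to_layer_list()     # flatten the traced specs into an execution sequence
pipelinable.policy = "uniform"  # one of 'uniform', 'balanced', 'customized'
# each pipeline rank materializes only its own partition of the layer list
stage = pipelinable.partition(num_chunks=1, pipeline_size=2, rank=0)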
from typing import List, Dict, Tuple
import os
import threading
from typing import Dict, List, Tuple
from torch.distributed import rpc
import torch.distributed as dist
from torch.distributed import rpc
from colossalai.tensor import ProcessGroup
from colossalai.legacy.tensor import ProcessGroup
class PipelineProcessGroup:
......
from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine
from .utils import pytree_map
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
\ No newline at end of file
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
......@@ -12,9 +12,9 @@ from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (
from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.legacy.pipeline.pipeline_process_group import ppg
from colossalai.legacy.pipeline.rpc.utils import (
get_batch_lengths,
pyobj_map,
pytree_filter,
......