Unverified commit 393f5940, authored by Super Daniel, committed by GitHub
Browse files

[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions...

[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor meta registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.
parent e8d8eda5
import inspect
import math
import threading
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Any, Tuple, Dict, Callable
from functools import partial
from abc import ABC, abstractmethod
import math
import inspect
from typing import Any, Callable, Dict, List, Tuple
import torch
from torch import nn
import torch.distributed.rpc as rpc
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from torch import autograd
from torch import optim
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug)
from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
split_batch, tensor_shape_list, type_detail)
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
class Phase(Enum):
......@@ -195,7 +191,6 @@ class WorkerBase(ABC):
if isinstance(output, Future):
output = output.wait()
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
output_work_item.refcount += 1
# all consumers have been satisfied, the work_item can be released
......@@ -250,9 +245,6 @@ class WorkerBase(ABC):
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
if use_color_debug:
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
# just for last pp_rank
......@@ -273,9 +265,6 @@ class WorkerBase(ABC):
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, False)
if use_color_debug:
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
......@@ -297,23 +286,14 @@ class WorkerBase(ABC):
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
if use_color_debug:
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
'data dispatch', 'magenta')
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_producer
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
def subscribe_consumer(self, microbatch_id: int):
......@@ -328,10 +308,6 @@ class WorkerBase(ABC):
subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device()
if use_color_debug:
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
'data dispatch', 'magenta')
for i in range(consumer_num):
consumer_stage_id = self.consumer_stage_ids[i]
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
......@@ -342,17 +318,11 @@ class WorkerBase(ABC):
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
microbatch_id, None, self.num_microbatches, False)
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.BACKWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_consumer
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
def _get_producer_consumer(self) -> None:
......@@ -406,11 +376,6 @@ class WorkerBase(ABC):
is_first_stage = self.is_first_stage()
is_last_stage = self.is_last_stage()
# if self.pp_rank == 0:
# print(
# f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
# )
if phase == Phase.FORWARD:
# remind its consumer to get data before forward
if not is_last_stage:
......@@ -470,8 +435,6 @@ class WorkerBase(ABC):
else:
consume_result = self.module_partition(*args, **kwargs)
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
if is_last_stage and self.criterion:
with self.label_lock:
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
......@@ -539,10 +502,6 @@ class WorkerBase(ABC):
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
# for input_node in stage_input_args:
# if isinstance(input_node, torch.Tensor):
# consume_result.append(input_node.grad)
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
......@@ -593,11 +552,6 @@ class WorkerBase(ABC):
with self.work_list_condition_lock:
work_item = self.work_list.pop(work_item_key)
if use_color_debug:
color_debug(
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
'work loop', 'green')
with self.output_list_condition_lock:
# assert work_item_key not in self.output_list
self.output_list[work_item_key] = work_item
......@@ -605,11 +559,6 @@ class WorkerBase(ABC):
consume_result = self._consume_work_item_by_phase(work_item)
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
'work loop', 'green')
work_item.output.set_result(consume_result)
# if is last step in one batch reset context and do step
......
from typing import List, Callable, Dict
import threading
from typing import Callable, Dict, List
import torch
import torch.distributed as dist
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage
......
from typing import List, Any, Tuple, Dict, Callable, Type, Union
import argparse
import os
import warnings
import argparse
from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch
import torch.multiprocessing as mp
from torch.futures import Future
import torch.distributed.rpc as rpc
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from colorama import Back, Style
import torch.multiprocessing as mp
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
# config for debug and test
use_color_debug = False
def color_debug(text, prefix=' ', color='blue'):
    """Print a debug message whose prefix sits on a colored background.

    Args:
        text: Message printed after the color reset code.
        prefix: Label printed right after the background color code.
        color: Name of an attribute of colorama's ``Back``; it is
            upper-cased before the lookup, so ``'blue'`` selects ``Back.BLUE``.

    Raises:
        AttributeError: If ``Back`` has no attribute named ``color.upper()``.
    """
    # Resolve the background escape code first, then emit everything in a
    # single print call so the pieces stay space-separated.
    background = getattr(Back, color.upper())
    print(background, prefix, Style.RESET_ALL, text)
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.futures import Future
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
......
import copy
import colossalai
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
import torch.fx
import colossalai
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.core import global_context as gpc
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.algorithms import solver_rotor
from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.core import global_context as gpc
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
try:
......@@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0):
graph = tracer.trace(model, meta_args={"x": data})
graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__)
if META_COMPATIBILITY:
if is_compatible_with_meta():
data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data_meta)
......
from typing import Callable
import copy
import re
from typing import Callable
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
from torch.fx import GraphModule
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
try:
......@@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule):
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
model_cls: Callable[[], torch.nn.Module]):
criterion = torch.nn.MSELoss()
data = torch.rand(2, 3, 32, 32)
label = torch.rand(2, 5)
m.cuda()
data = torch.rand(2, 3, 32, 32).cuda()
label = torch.rand(2, 5).cuda()
loss = criterion(m(data), label)
loss.backward()
loss = criterion(gm(data), label)
......@@ -77,7 +80,7 @@ def _run_ckpt_solver(rank):
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
if solver == solver_rotor:
......
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
import pytest
import torch
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor, linearize
from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
import pytest
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY:
from colossalai.fx.passes.algorithms import linearize, solver_rotor
from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
try:
......
import torch
import torch.nn as nn
import colossalai
import colossalai.nn as col_nn
from torch.fx import symbolic_trace
import pytest
import torch
import torch.nn as nn
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
from colossalai.fx.passes.utils import get_comm_size
from colossalai import META_COMPATIBILITY
import pytest
from torch.fx import symbolic_trace
is_compatible = is_compatible_with_meta()
if is_compatible:
from colossalai.fx.profiler import MetaTensor
MODEL_DIM = 16
BATCH_SIZE = 8
......@@ -31,12 +35,12 @@ class MLP(torch.nn.Module):
return x
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_comm_size_compute():
from colossalai.fx.profiler import MetaTensor
model = MLP(MODEL_DIM)
input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu')
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
gm = symbolic_trace(model)
if is_compatible:
input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(input_sample)
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
......
from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai import META_COMPATIBILITY
import pytest
import torch
import torch.nn as nn
from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
aten = torch.ops.aten
......@@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_meta_aten():
for (aten_op, requires_backward), v in registered_meta.items():
for f, x in v:
......
import torchvision.models as tm
import pytest
import timm.models as tmm
import torch
from colossalai import META_COMPATIBILITY
import pytest
import torchvision.models as tm
from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
tm_models = [
......@@ -27,7 +27,7 @@ tmm_models = [
]
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models():
for m in tm_models:
model = m()
......@@ -35,7 +35,7 @@ def test_torchvision_models():
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models():
for m in tmm_models:
model = m()
......
import torchvision.models as tm
import pytest
import timm.models as tmm
import torch
from colossalai import META_COMPATIBILITY
import pytest
import torchvision.models as tm
from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY:
if is_compatible_with_meta():
from colossalai.fx import meta_trace
tm_models = [
......@@ -27,7 +27,7 @@ tmm_models = [
]
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models_trace():
for m in tm_models:
model = m()
......@@ -35,7 +35,7 @@ def test_torchvision_models_trace():
graph = meta_trace(model, torch.device('cpu'), data)
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models_trace():
for m in tmm_models:
model = m()
......
import torch
from torch.fx import symbolic_trace
from colossalai import META_COMPATIBILITY
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from torch.fx import symbolic_trace
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
BATCH_SIZE = 2
DIM_IN = 4
......@@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
def test_meta_info_prop():
model = torch.nn.Linear(DIM_IN, DIM_OUT)
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
if is_compatible_with_meta():
input_sample = MetaTensor(input_sample, fake_device='cpu')
orig_output = model(input_sample)
gm = symbolic_trace(model)
......
import os
import argparse
import os
import warnings
import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch.optim import SGD, Adam, RMSprop, Optimizer
from torch._C._distributed_rpc import _is_current_rpc_agent_set
import torch.distributed as dist
from colorama import Back, Style
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.logging import disable_existing_loggers
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
from torch import nn
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.optim import SGD, Adam, Optimizer, RMSprop
rpc_is_initialized = _is_current_rpc_agent_set
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment