"vscode:/vscode.git/clone" did not exist on "2edbef13cc2f08e3d74ea72a68d2299f3e7cdbb7"
Unverified commit 393f5940, authored by Super Daniel and committed by GitHub
Browse files

[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions...

[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)

* [fx] move meta registration

* [fx] fix tests.

* [fx] fix test.

* [fx] fix.

* [meta] refactor _meta_registration.py.

* [fx] add compatibility descriptions.

* [fx] polish import.

* [fx] add a decorator.

* [fx] fix tests.

* [fx] remove print.

* [fx] edit raise error.

* [fx] edit raise error.

* [fx] add type hint.

* [fx] fix import in experimental.

* [rpc] remove color debug.

* [meta] fix naming.
parent e8d8eda5
import inspect
import math
import threading import threading
from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import List, Any, Tuple, Dict, Callable
from functools import partial from functools import partial
from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Tuple
import math
import inspect
import torch import torch
from torch import nn
import torch.distributed.rpc as rpc import torch.distributed.rpc as rpc
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from torch import autograd
from torch import optim
from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail, from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
pytree_map, pytree_filter, get_real_args_kwargs, use_color_debug) split_batch, tensor_shape_list, type_detail)
from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
class Phase(Enum): class Phase(Enum):
...@@ -195,7 +191,6 @@ class WorkerBase(ABC): ...@@ -195,7 +191,6 @@ class WorkerBase(ABC):
if isinstance(output, Future): if isinstance(output, Future):
output = output.wait() output = output.wait()
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
output_work_item.refcount += 1 output_work_item.refcount += 1
# all consumers have been satisfied, the work_item can be released # all consumers have been satisfied, the work_item can be released
...@@ -250,9 +245,6 @@ class WorkerBase(ABC): ...@@ -250,9 +245,6 @@ class WorkerBase(ABC):
self.num_microbatches, forward_only) self.num_microbatches, forward_only)
with self.work_list_condition_lock: with self.work_list_condition_lock:
self.work_list[key] = work_item self.work_list[key] = work_item
if use_color_debug:
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all() self.work_list_condition_lock.notify_all()
# just for last pp_rank # just for last pp_rank
...@@ -273,9 +265,6 @@ class WorkerBase(ABC): ...@@ -273,9 +265,6 @@ class WorkerBase(ABC):
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, False) self.num_microbatches, False)
if use_color_debug:
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
self.work_list[key] = work_item self.work_list[key] = work_item
self.work_list_condition_lock.notify_all() self.work_list_condition_lock.notify_all()
...@@ -297,23 +286,14 @@ class WorkerBase(ABC): ...@@ -297,23 +286,14 @@ class WorkerBase(ABC):
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
if use_color_debug:
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
'data dispatch', 'magenta')
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only) microbatch_id, None, self.num_microbatches, forward_only)
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list # add work_item to work_list
with self.work_list_condition_lock: with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD) key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list assert key not in self.work_list
self.work_list[key] = work_item_from_producer self.work_list[key] = work_item_from_producer
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all() self.work_list_condition_lock.notify_all()
def subscribe_consumer(self, microbatch_id: int): def subscribe_consumer(self, microbatch_id: int):
...@@ -328,10 +308,6 @@ class WorkerBase(ABC): ...@@ -328,10 +308,6 @@ class WorkerBase(ABC):
subscribe_backward_futures: List[Future] = [None] * consumer_num subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device() output = self._get_future_by_device()
if use_color_debug:
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
'data dispatch', 'magenta')
for i in range(consumer_num): for i in range(consumer_num):
consumer_stage_id = self.consumer_stage_ids[i] consumer_stage_id = self.consumer_stage_ids[i]
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD) consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
...@@ -342,17 +318,11 @@ class WorkerBase(ABC): ...@@ -342,17 +318,11 @@ class WorkerBase(ABC):
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output, work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
microbatch_id, None, self.num_microbatches, False) microbatch_id, None, self.num_microbatches, False)
# color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list # add work_item to work_list
with self.work_list_condition_lock: with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.BACKWARD) key = UniqueKey(microbatch_id, Phase.BACKWARD)
assert key not in self.work_list assert key not in self.work_list
self.work_list[key] = work_item_from_consumer self.work_list[key] = work_item_from_consumer
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all() self.work_list_condition_lock.notify_all()
def _get_producer_consumer(self) -> None: def _get_producer_consumer(self) -> None:
...@@ -406,11 +376,6 @@ class WorkerBase(ABC): ...@@ -406,11 +376,6 @@ class WorkerBase(ABC):
is_first_stage = self.is_first_stage() is_first_stage = self.is_first_stage()
is_last_stage = self.is_last_stage() is_last_stage = self.is_last_stage()
# if self.pp_rank == 0:
# print(
# f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
# )
if phase == Phase.FORWARD: if phase == Phase.FORWARD:
# remind its consumer to get data before forward # remind its consumer to get data before forward
if not is_last_stage: if not is_last_stage:
...@@ -470,8 +435,6 @@ class WorkerBase(ABC): ...@@ -470,8 +435,6 @@ class WorkerBase(ABC):
else: else:
consume_result = self.module_partition(*args, **kwargs) consume_result = self.module_partition(*args, **kwargs)
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
if is_last_stage and self.criterion: if is_last_stage and self.criterion:
with self.label_lock: with self.label_lock:
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels) self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
...@@ -539,10 +502,6 @@ class WorkerBase(ABC): ...@@ -539,10 +502,6 @@ class WorkerBase(ABC):
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor) pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor) pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
# for input_node in stage_input_args:
# if isinstance(input_node, torch.Tensor):
# consume_result.append(input_node.grad)
else: else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
...@@ -593,11 +552,6 @@ class WorkerBase(ABC): ...@@ -593,11 +552,6 @@ class WorkerBase(ABC):
with self.work_list_condition_lock: with self.work_list_condition_lock:
work_item = self.work_list.pop(work_item_key) work_item = self.work_list.pop(work_item_key)
if use_color_debug:
color_debug(
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
'work loop', 'green')
with self.output_list_condition_lock: with self.output_list_condition_lock:
# assert work_item_key not in self.output_list # assert work_item_key not in self.output_list
self.output_list[work_item_key] = work_item self.output_list[work_item_key] = work_item
...@@ -605,11 +559,6 @@ class WorkerBase(ABC): ...@@ -605,11 +559,6 @@ class WorkerBase(ABC):
consume_result = self._consume_work_item_by_phase(work_item) consume_result = self._consume_work_item_by_phase(work_item)
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
'work loop', 'green')
work_item.output.set_result(consume_result) work_item.output.set_result(consume_result)
# if is last step in one batch reset context and do step # if is last step in one batch reset context and do step
......
from typing import List, Callable, Dict
import threading import threading
from typing import Callable, Dict, List
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
# Implementation of different Pipeline schedule # Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage # <strategy>Worker defines the worker for each stage
......
from typing import List, Any, Tuple, Dict, Callable, Type, Union import argparse
import os import os
import warnings import warnings
import argparse from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch import torch
import torch.multiprocessing as mp
from torch.futures import Future
import torch.distributed.rpc as rpc import torch.distributed.rpc as rpc
from torch._C._distributed_rpc import _is_current_rpc_agent_set import torch.multiprocessing as mp
from colorama import Back, Style
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.pipeline_process_group import ppg
from torch._C._distributed_rpc import _is_current_rpc_agent_set
# config for debug and test from torch.futures import Future
use_color_debug = False
def color_debug(text, prefix=' ', color='blue'):
color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
......
import copy import copy
import colossalai
import pytest
import torch import torch
import torch.fx
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torchvision.models as tm import torchvision.models as tm
import torch.fx from colossalai.core import global_context as gpc
import colossalai
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.algorithms import solver_rotor from colossalai.fx.passes.algorithms import solver_rotor
from colossalai.fx.passes.algorithms.operation import Sequence from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.core import global_context as gpc from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port from colossalai.utils import free_port
import pytest
from colossalai import META_COMPATIBILITY if is_compatible_with_meta():
if META_COMPATIBILITY:
from colossalai.fx.profiler.tensor import MetaTensor from colossalai.fx.profiler.tensor import MetaTensor
try: try:
...@@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0): ...@@ -34,7 +36,7 @@ def _run_C_solver_consistency_test(rank=0):
graph = tracer.trace(model, meta_args={"x": data}) graph = tracer.trace(model, meta_args={"x": data})
graph.set_codegen(ActivationCheckpointCodeGen()) graph.set_codegen(ActivationCheckpointCodeGen())
gm = ColoGraphModule(model, graph, model.__class__.__name__) gm = ColoGraphModule(model, graph, model.__class__.__name__)
if META_COMPATIBILITY: if is_compatible_with_meta():
data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device) data_meta = MetaTensor(data, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(data_meta) MetaInfoProp(gm).run(data_meta)
......
from typing import Callable
import copy import copy
import re import re
from typing import Callable
import colossalai
import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torchvision.models as tm import torchvision.models as tm
from torch.fx import GraphModule from colossalai.core import global_context as gpc
import colossalai
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.core import global_context as gpc from torch.fx import GraphModule
import pytest
from colossalai import META_COMPATIBILITY if is_compatible_with_meta():
if META_COMPATIBILITY:
from colossalai.fx.profiler.tensor import MetaTensor from colossalai.fx.profiler.tensor import MetaTensor
try: try:
...@@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule): ...@@ -54,8 +56,9 @@ def _is_graph_linearized(gm: GraphModule):
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
model_cls: Callable[[], torch.nn.Module]): model_cls: Callable[[], torch.nn.Module]):
criterion = torch.nn.MSELoss() criterion = torch.nn.MSELoss()
data = torch.rand(2, 3, 32, 32) m.cuda()
label = torch.rand(2, 5) data = torch.rand(2, 3, 32, 32).cuda()
label = torch.rand(2, 5).cuda()
loss = criterion(m(data), label) loss = criterion(m(data), label)
loss.backward() loss.backward()
loss = criterion(gm(data), label) loss = criterion(gm(data), label)
...@@ -77,7 +80,7 @@ def _run_ckpt_solver(rank): ...@@ -77,7 +80,7 @@ def _run_ckpt_solver(rank):
m = model_cls(num_classes=5) m = model_cls(num_classes=5)
graph = tracer.trace(root=m) graph = tracer.trace(root=m)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__) gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda')) MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
codegen = ActivationCheckpointCodeGen() codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen) gm.graph.set_codegen(codegen)
if solver == solver_rotor: if solver == solver_rotor:
......
from colossalai.fx.passes.meta_info_prop import MetaInfoProp import pytest
import torch import torch
import torchvision.models as tm import torchvision.models as tm
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import solver_rotor, linearize from colossalai.fx.passes.algorithms import linearize, solver_rotor
from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
import pytest from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai import META_COMPATIBILITY
if META_COMPATIBILITY: if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor from colossalai.fx.profiler.tensor import MetaTensor
try: try:
......
import torch
import torch.nn as nn
import colossalai import colossalai
import colossalai.nn as col_nn import colossalai.nn as col_nn
from torch.fx import symbolic_trace import pytest
import torch
import torch.nn as nn
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
from colossalai.fx.passes.utils import get_comm_size from colossalai.fx.passes.utils import get_comm_size
from colossalai import META_COMPATIBILITY from torch.fx import symbolic_trace
import pytest
is_compatible = is_compatible_with_meta()
if is_compatible:
from colossalai.fx.profiler import MetaTensor
MODEL_DIM = 16 MODEL_DIM = 16
BATCH_SIZE = 8 BATCH_SIZE = 8
...@@ -31,12 +35,12 @@ class MLP(torch.nn.Module): ...@@ -31,12 +35,12 @@ class MLP(torch.nn.Module):
return x return x
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_comm_size_compute(): def test_comm_size_compute():
from colossalai.fx.profiler import MetaTensor
model = MLP(MODEL_DIM) model = MLP(MODEL_DIM)
input_sample = MetaTensor(torch.rand(BATCH_SIZE, MODEL_DIM, device='meta'), fake_device='cpu') input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
gm = symbolic_trace(model) gm = symbolic_trace(model)
if is_compatible:
input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
MetaInfoProp(gm).run(input_sample) MetaInfoProp(gm).run(input_sample)
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE) annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
split_model, split_submodules = split_with_split_nodes_pass(annotated_model) split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
......
from typing import Any, Callable, Union from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai import META_COMPATIBILITY
import pytest import pytest
import torch
import torch.nn as nn
from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY: if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
aten = torch.ops.aten aten = torch.ops.aten
...@@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac ...@@ -71,7 +70,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
compare_all(x.grad, meta_x.grad) compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_meta_aten(): def test_meta_aten():
for (aten_op, requires_backward), v in registered_meta.items(): for (aten_op, requires_backward), v in registered_meta.items():
for f, x in v: for f, x in v:
......
import torchvision.models as tm import pytest
import timm.models as tmm import timm.models as tmm
import torch import torch
from colossalai import META_COMPATIBILITY import torchvision.models as tm
import pytest from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY: if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
tm_models = [ tm_models = [
...@@ -27,7 +27,7 @@ tmm_models = [ ...@@ -27,7 +27,7 @@ tmm_models = [
] ]
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models(): def test_torchvision_models():
for m in tm_models: for m in tm_models:
model = m() model = m()
...@@ -35,7 +35,7 @@ def test_torchvision_models(): ...@@ -35,7 +35,7 @@ def test_torchvision_models():
model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward()
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models(): def test_timm_models():
for m in tmm_models: for m in tmm_models:
model = m() model = m()
......
import torchvision.models as tm import pytest
import timm.models as tmm import timm.models as tmm
import torch import torch
from colossalai import META_COMPATIBILITY import torchvision.models as tm
import pytest from colossalai.fx._compatibility import is_compatible_with_meta
if META_COMPATIBILITY: if is_compatible_with_meta():
from colossalai.fx import meta_trace from colossalai.fx import meta_trace
tm_models = [ tm_models = [
...@@ -27,7 +27,7 @@ tmm_models = [ ...@@ -27,7 +27,7 @@ tmm_models = [
] ]
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_torchvision_models_trace(): def test_torchvision_models_trace():
for m in tm_models: for m in tm_models:
model = m() model = m()
...@@ -35,7 +35,7 @@ def test_torchvision_models_trace(): ...@@ -35,7 +35,7 @@ def test_torchvision_models_trace():
graph = meta_trace(model, torch.device('cpu'), data) graph = meta_trace(model, torch.device('cpu'), data)
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0') @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0')
def test_timm_models_trace(): def test_timm_models_trace():
for m in tmm_models: for m in tmm_models:
model = m() model = m()
......
import torch import torch
from torch.fx import symbolic_trace from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai import META_COMPATIBILITY
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from torch.fx import symbolic_trace
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
BATCH_SIZE = 2 BATCH_SIZE = 2
DIM_IN = 4 DIM_IN = 4
...@@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): ...@@ -18,8 +21,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
def test_meta_info_prop(): def test_meta_info_prop():
model = torch.nn.Linear(DIM_IN, DIM_OUT) model = torch.nn.Linear(DIM_IN, DIM_OUT)
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
if META_COMPATIBILITY: if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
input_sample = MetaTensor(input_sample, fake_device='cpu') input_sample = MetaTensor(input_sample, fake_device='cpu')
orig_output = model(input_sample) orig_output = model(input_sample)
gm = symbolic_trace(model) gm = symbolic_trace(model)
......
import os
import argparse import argparse
import os
import warnings import warnings
import torch import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch.optim import SGD, Adam, RMSprop, Optimizer
from torch._C._distributed_rpc import _is_current_rpc_agent_set
import torch.distributed as dist import torch.distributed as dist
from colorama import Back, Style import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.logging import disable_existing_loggers
from colossalai import launch from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
from torch import nn
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.optim import SGD, Adam, Optimizer, RMSprop
rpc_is_initialized = _is_current_rpc_agent_set rpc_is_initialized = _is_current_rpc_agent_set
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment