Unverified Commit 75abc75c authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[fx] fixed compatibility issue with torch 1.10 (#1331)

parent 069d6fdc
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
from torch.fx.graph_module import GraphModule from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
from packaging import version
import inspect import inspect
...@@ -233,10 +234,13 @@ def split_module( ...@@ -233,10 +234,13 @@ def split_module(
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes: for node in m.graph.nodes:
if node.op == 'placeholder': if node.op == 'placeholder':
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
type_expr=node.type, else:
default_value=default_value) default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy() base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again: # Do some things iterating over the partitions in topological order again:
......
...@@ -3,6 +3,7 @@ from ..registry import meta_patched_function ...@@ -3,6 +3,7 @@ from ..registry import meta_patched_function
@meta_patched_function.register(torch.matmul) @meta_patched_function.register(torch.matmul)
@meta_patched_function.register('matmul') # for built-in op @
def torch_matmul(input, other, *, out=None): def torch_matmul(input, other, *, out=None):
# copied from huggingface.utils.fx # copied from huggingface.utils.fx
d1 = input.dim() d1 = input.dim()
......
...@@ -96,6 +96,9 @@ class ColoTracer(Tracer): ...@@ -96,6 +96,9 @@ class ColoTracer(Tracer):
# fetch patched function # fetch patched function
if meta_patched_function.has(target): if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target) meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
else: else:
meta_target = target meta_target = target
......
import torch import torch
import pytest import timm.models as tm
try:
import timm.models as tm
except:
pass
from timm_utils import split_model_and_compare_output from timm_utils import split_model_and_compare_output
......
import torch import torch
try: import torchvision
import torchvision.models as tm import torchvision.models as tm
except:
pass
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule from torch.fx import GraphModule
from packaging import version
import random import random
import numpy as np import numpy as np
import inspect import inspect
import pytest
MANUAL_SEED = 0 MANUAL_SEED = 0
random.seed(MANUAL_SEED) random.seed(MANUAL_SEED)
...@@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True ...@@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
def test_torchvision_models(): def test_torchvision_models():
MODEL_LIST = [ MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5 tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
] ]
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
tracer = ColoTracer() tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224) data = torch.rand(2, 3, 224, 224)
......
import torch import torch
import pytest import timm.models as tm
try:
import timm.models as tm
except:
pass
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from torch.fx import GraphModule from torch.fx import GraphModule
......
import torch import torch
import pytest import torchvision
try: import torchvision.models as tm
import torchvision.models as tm from packaging import version
except:
pass
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from torch.fx import GraphModule from torch.fx import GraphModule
...@@ -11,16 +9,22 @@ from torch.fx import GraphModule ...@@ -11,16 +9,22 @@ from torch.fx import GraphModule
def test_torchvision_models(): def test_torchvision_models():
MODEL_LIST = [ MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0 tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
] ]
RANDOMIZED_MODELS = [tm.efficientnet_b0]
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
RANDOMIZED_MODELS.append(tm.convnext_small)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
tracer = ColoTracer() tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224) data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST: for model_cls in MODEL_LIST:
if model_cls in [tm.convnext_small, tm.efficientnet_b0]: if model_cls in RANDOMIZED_MODELS:
# remove the impact of randomness # remove the impact of randomness
model = model_cls(stochastic_depth_prob=0) model = model_cls(stochastic_depth_prob=0)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment