Unverified Commit e859380b authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[fx] support module with bias addition (#1780)

* [autoparallel] refactor tracer to fix bias addition issue

* [fx] support module with bias addition

* create bias_addition_module

* refactor file structure

* polish code

* fix unit test
parent f3f19a5c
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Linear)
......
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.LayerNorm)
......
import math
import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.AvgPool1d)
......
import torch
from ..registry import meta_patched_module
from typing import Optional
import torch
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
......
......@@ -23,3 +23,5 @@ class PatchRegistry:
meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
......@@ -18,11 +18,10 @@ from torch.fx import Node, Tracer
from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
from torch.fx.proxy import ParameterProxy, Proxy
from colossalai.fx.tracer.meta_patch import meta_patched_module
from ..proxy import ColoProxy
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .meta_patch import meta_patched_function, meta_patched_module
from .bias_addition_patch import module_to_func_dict
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
__all__ = ['ColoTracer']
......@@ -79,18 +78,126 @@ class ColoTracer(Tracer):
"""
Create a proxy for different kinds of operations.
"""
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
if self.tracer_type == TracerType.DEFAULT:
# since meta_args is not given
# we just fall back to the original torch.fx.Tracer
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
return proxy
# if graph is traced for auto parallelism module, some extra node will be added during
# graph construction to deal with the compatability between bias addition and all reduce.
# if no extra manipulation is applied, we just pass the origin arguments to create_proxy function
# to create node on computation graph
origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
# dispatch the arguments generator depending on the kind and target in origin arguments.
args_metas, _ = extract_meta(*args, **kwargs)
if kind == "call_function":
if bias_addition_function.has(target):
return bias_addition_function.get(target)(self, target, args, kwargs)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
return bias_addition_function.get(target.__name__)(self, target, args, kwargs)
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_function.has(method):
return bias_addition_function.get(method)(self, target, args, kwargs)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):
raise AttributeError(f"{self} does not have an attribute called orig_forward")
self._disable_module_getattr = True
try:
mod = self.root.get_submodule(target)
mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute)
return handle.generate()
finally:
self._disable_module_getattr = False
# create nodes using patched arguments
proxy = super().create_proxy(*origin_arguments)
proxy: ColoProxy
meta_out = self._meta_data_computing(
kind,
target,
args,
kwargs,
)
proxy.meta_data = meta_out
return proxy
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
lambda node: ParameterProxy(self, node, n, attr_val))
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
return attr_val
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
module_qualified_name = self.path_of_module(m)
# a leaf module is the torch.nn.Module subclasses starting with `torch.nn`
# which means customized modules are not leaf module by default
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
else:
return forward(*args, **kwargs)
def proxy(self, node) -> Proxy:
"""
Returns a ColoProxy object.
"""
return self.proxy_cls(node, self)
def _configure_tracer_type(self, tracer_type: TracerType):
if tracer_type == TracerType.DEFAULT:
self.proxy_cls = Proxy
self.tracer_type = TracerType.DEFAULT
elif tracer_type == TracerType.META:
self.proxy_cls = ColoProxy
self.tracer_type = TracerType.META
else:
raise ValueError(f"Unrecognised tracer type {tracer_type}")
def _meta_data_computing(self, kind, target, args, kwargs):
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
proxy.meta_data = self.meta_args[target]
return proxy
meta_out = self.meta_args[target]
return meta_out
if target in self.orig_torch_tensor_methods:
# NOTE: tensor constructors in PyTorch define the `device` argument as
......@@ -154,75 +261,12 @@ class ColoTracer(Tracer):
finally:
self._disable_module_getattr = False
else:
return proxy
return None
if not isinstance(proxy, Proxy):
raise ValueError("Don't support composite output yet")
proxy.meta_data = meta_out
except Exception as e:
raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
return proxy
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
if getattr(self, "_disable_module_getattr", False):
return attr_val
else:
# return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
lambda node: ParameterProxy(self, node, n, attr_val))
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
parameter_proxy_cache)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
parameter_proxy_cache)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
return attr_val
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
module_qualified_name = self.path_of_module(m)
# a leaf module is the torch.nn.Module subclasses starting with `torch.nn`
# which means customized modules are not leaf module by default
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
else:
return forward(*args, **kwargs)
def proxy(self, node) -> Proxy:
"""
Returns a ColoProxy object.
"""
return self.proxy_cls(node, self)
def _configure_tracer_type(self, tracer_type: TracerType):
if tracer_type == TracerType.DEFAULT:
self.proxy_cls = Proxy
self.tracer_type = TracerType.DEFAULT
elif tracer_type == TracerType.META:
self.proxy_cls = ColoProxy
self.tracer_type = TracerType.META
else:
raise ValueError(f"Unrecognised tracer type {tracer_type}")
return meta_out
def trace(self,
root: nn.Module,
......
from copy import deepcopy
from pickletools import optimize
import pytest
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
class ConvModel(nn.Module):
......@@ -67,7 +68,8 @@ def test_cost_graph():
for node in graph.nodes:
if node.op == 'output':
continue
all_node_pairs.append((node, node.next))
for child in node.users.keys():
all_node_pairs.append((node, child))
for node_pair in all_node_pairs:
assert node_pair in cost_graph.edge_costs
......@@ -75,14 +77,14 @@ def test_cost_graph():
# construct merged node pairs
merged_node_pairs = []
node_list = list(graph.nodes)
# add (x, conv) and (conv, output) into check node pairs
merged_node_pairs.append((node_list[0], node_list[2]))
merged_node_pairs.append((node_list[2], node_list[-1]))
# (conv1, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 246019.30000000002, (5, 0): 246019.30000000002, (6, 0): 123009.1, (7, 0): 123009.1, (8, 0): 123009.1, (9, 0): 123009.1, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): 246019.30000000002, (14, 0): 246019.30000000002}
# (x, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002}
# add (conv1_weight, conv2d), (conv1_bias, view), (conv2d, add), (view, add), (add, output), (x, conv2d) into check node pairs
merged_node_pairs.append((node_list[0], node_list[4]))
merged_node_pairs.append((node_list[2], node_list[4]))
merged_node_pairs.append((node_list[3], node_list[5]))
merged_node_pairs.append((node_list[5], node_list[6]))
merged_node_pairs.append((node_list[4], node_list[6]))
merged_node_pairs.append((node_list[6], node_list[-1]))
cost_graph.simplify_graph()
for node_pair in all_node_pairs:
if node_pair in merged_node_pairs:
assert node_pair in cost_graph.edge_costs
......
import pytest
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
class ConvModel(nn.Module):
......@@ -37,52 +39,22 @@ def test_conv_handler():
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# return add
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, conv, output]
nodes = [node for node in gm.graph.nodes]
# find the sharding strategies for the input node of the conv node
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(nodes[1])
sharding_option = (None, 0, 1)
for first_sharding_index in sharding_option:
for second_sharding_index in sharding_option:
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
continue
if first_sharding_index is None:
first_dim_spec = _DimSpec([])
else:
first_dim_spec = _DimSpec([first_sharding_index])
if second_sharding_index is None:
second_dim_spec = _DimSpec([])
else:
second_dim_spec = _DimSpec([second_sharding_index])
replica_dim_spec = _DimSpec([])
sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=entire_shape,
sharding_sequence=sharding_sequence)
strategy_name = str(sharding_spec.sharding_sequence)
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
strategies_vector_for_input.append(sharding_strategy)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
# generate conv strategy
strategies_vector = StrategiesVector(node=nodes[2])
conv_handler = ConvHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
conv_handler.register_strategy()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
conv_node = list(graph.nodes)[4]
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]
strategy_name_list = [strategy.name for strategy in conv_node.strategies_vector]
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
......
import pytest
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
class LinearModel(nn.Module):
......@@ -23,6 +25,7 @@ class LinearModel(nn.Module):
return x
@pytest.mark.skip('F.linear is not supported in deprecated handler')
def test_dot_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
......@@ -37,52 +40,23 @@ def test_dot_handler():
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %linear_weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
# return add
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, linear, output]
nodes = [node for node in gm.graph.nodes]
# find the sharding strategies for the input node of the conv node
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(node=nodes[1])
sharding_option = (None, 0, 1)
for first_sharding_index in sharding_option:
for second_sharding_index in sharding_option:
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
continue
if first_sharding_index is None:
first_dim_spec = _DimSpec([])
else:
first_dim_spec = _DimSpec([first_sharding_index])
if second_sharding_index is None:
second_dim_spec = _DimSpec([])
else:
second_dim_spec = _DimSpec([second_sharding_index])
sharding_sequence = [first_dim_spec, second_dim_spec]
sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=entire_shape,
sharding_sequence=sharding_sequence)
strategy_name = str(sharding_spec.sharding_sequence)
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
strategies_vector_for_input.append(sharding_strategy)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
# generate dot strategy
strategies_vector = StrategiesVector(node=nodes[2])
dot_handler = DotHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
strategies_vector = dot_handler.register_strategy()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
linear_node = list(graph.nodes)[4]
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
strategy_name_list = [strategy.name for strategy in strategies_vector]
strategy_name_list = [strategy.name for strategy in linear_node.strategies_vector]
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
......
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
class ConvModel(nn.Module):
......@@ -33,7 +32,12 @@ def test_conv_handler():
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# %flatten : [#users=1] = call_function[target=torch.flatten](args = (%add,), kwargs = {})
# return flatten
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
......@@ -44,10 +48,10 @@ def test_conv_handler():
strategies_constructor.build_strategies_and_cost()
strategy_map = strategies_constructor.strategy_map
conv_strategies = strategy_map[nodes[1]]
flatten_strategies = strategy_map[nodes[2]]
add_strategies = strategy_map[nodes[5]]
flatten_strategies = strategy_map[nodes[6]]
flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies]
for strategy in conv_strategies:
for strategy in add_strategies:
assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list
......
from copy import deepcopy
import pytest
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from copy import deepcopy
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
class ConvModel(nn.Module):
......@@ -40,9 +41,14 @@ def test_strategies_constructor():
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# return add
graph = tracer.trace(root=model, meta_args=input_sample)
print(graph)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
......@@ -63,12 +69,12 @@ def test_strategies_constructor():
# Third node is conv.
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
for strategy in strategies_constructor.leaf_strategies[2]:
for strategy in strategies_constructor.leaf_strategies[4]:
conv_check_list.remove(strategy.name)
assert len(conv_check_list) == 0
# In fast mode, output node only has replica strategy.
assert strategies_constructor.leaf_strategies[3][0].name == 'Replica Output'
assert strategies_constructor.leaf_strategies[7][0].name == 'Replica Output'
# check strategy_map
......@@ -81,15 +87,15 @@ def test_strategies_constructor():
mul = nodes[1]
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
# Third node is conv.
conv = nodes[2]
# fifth node is conv.
conv = nodes[4]
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
for strategy in strategies_constructor.strategy_map[conv]:
conv_check_list.remove(strategy.name)
assert len(conv_check_list) == 0
# In fast mode, output node only has replica strategy.
output = nodes[3]
output = nodes[-1]
assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'
......
import transformers
import torch
import pytest
import torch
import transformers
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGHT = 16
@pytest.mark.skip('balance split v2 is not ready')
def test_single_sentence_albert():
MODEL_LIST = [
transformers.AlbertModel,
......
import transformers
import torch
import pytest
import torch
import transformers
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGHT = 16
@pytest.mark.skip('balance split v2 is not ready')
def test_single_sentence_bert():
MODEL_LIST = [
transformers.BertModel,
......
import transformers
import torch
import pytest
import torch
import transformers
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 64
......@@ -9,6 +9,7 @@ NUM_EPOCHS = 2
NUM_CHUNKS = 1
@pytest.mark.skip('balance split v2 is not ready')
def test_gpt():
MODEL_LIST = [
transformers.GPT2Model,
......
import pytest
import transformers
import torch
import transformers
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGHT = 16
@pytest.mark.skip('balance split v2 is not ready')
def test_opt():
MODEL_LIST = [
transformers.OPTModel,
......
import pytest
import transformers
import torch
import transformers
from hf_utils import split_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGHT = 16
@pytest.mark.skip('balance split v2 is not ready')
def test_t5():
MODEL_LIST = [
transformers.T5Model,
......
import torch
import pytest
import timm.models as tm
import torch
from timm_utils import split_model_and_compare_output
import pytest
@pytest.mark.skip('balance split v2 is not ready')
def test_timm_models_without_control_flow():
MODEL_LIST = [
......@@ -24,6 +25,7 @@ def test_timm_models_without_control_flow():
split_model_and_compare_output(model, data)
@pytest.mark.skip('balance split v2 is not ready')
def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True
......
import inspect
import random
import numpy as np
import pytest
import torch
import torchvision
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule
from packaging import version
import random
import numpy as np
import inspect
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
......@@ -16,6 +19,7 @@ torch.manual_seed(MANUAL_SEED)
torch.backends.cudnn.deterministic = True
@pytest.mark.skip('balance split v2 is not ready')
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
......
import torch
from colossalai.fx import ColoGraphModule, ColoTracer
class LinearModel(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features)
def forward(self, x):
x = self.linear(x)
x = x * 2
return x
class ConvModel(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
bias=bias)
def forward(self, x):
x = self.conv(x)
x = x * 2
return x
def test_linear_module():
model = LinearModel(3, 6)
tracer = ColoTracer()
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')})
# def forward(self, x : torch.Tensor):
# linear_weight = self.linear.weight
# linear_bias = self.linear.bias
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
# add = linear + linear_bias; linear = linear_bias = None
# mul = add * 2; add = None
# return mul
gm = ColoGraphModule(model, graph)
gm.recompile()
node_list = list(graph.nodes)
for node in node_list:
if node.op == 'output':
continue
assert hasattr(node, '_meta_data')
weight_node = node_list[1]
bias_node = node_list[2]
linear_node = node_list[3]
add_node = node_list[4]
assert weight_node._meta_data.shape == (6, 3)
assert bias_node._meta_data.shape == (6,)
assert linear_node._meta_data.shape == (3, 6)
assert add_node._meta_data.shape == (3, 6)
def test_conv_module():
model = ConvModel(3, 6, 2)
tracer = ColoTracer()
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
# def forward(self, x : torch.Tensor):
# conv_weight = self.conv.weight
# conv_bias = self.conv.bias
# conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
# view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
# add = conv2d + view; conv2d = view = None
# mul = add * 2; add = None
# return mul
gm = ColoGraphModule(model, graph)
gm.recompile()
node_list = list(graph.nodes)
for node in node_list:
if node.op == 'output':
continue
assert hasattr(node, '_meta_data')
weight_node = node_list[1]
bias_node = node_list[2]
conv_node = node_list[3]
view_node = node_list[4]
add_node = node_list[5]
assert weight_node._meta_data.shape == (6, 3, 2, 2)
assert bias_node._meta_data.shape == (6,)
assert conv_node._meta_data.shape == (4, 6, 63, 63)
assert view_node._meta_data.shape == (1, 6, 1, 1)
assert add_node._meta_data.shape == (4, 6, 63, 63)
if __name__ == '__main__':
test_linear_module()
test_conv_module()
import torch
import pytest
import timm.models as tm
from colossalai.fx import ColoTracer
import torch
from torch.fx import GraphModule
import pytest
from colossalai.fx import ColoTracer
def trace_and_compare(model_cls, tracer, data, meta_args=None):
......@@ -22,7 +23,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
with torch.no_grad():
fx_out = gm(data)
non_fx_out = model(data)
# compare output
if isinstance(fx_out, tuple):
# some models produce tuple as output
......@@ -30,7 +31,8 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}'
else:
assert torch.allclose(
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
fx_out, non_fx_out,
atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
def test_timm_models_without_control_flow():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment