Unverified Commit fee2af86 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] adapt autoparallel with new analyzer (#3261)

* [autoparallel] adapt autoparallel with new analyzer

* fix all node handler tests

* polish

* polish
parent e78a1e94
...@@ -5,6 +5,9 @@ import torch ...@@ -5,6 +5,9 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
...@@ -13,7 +16,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -13,7 +16,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -49,9 +51,11 @@ def check_linear_module_handler(rank, bias, input_shape, world_size, port): ...@@ -49,9 +51,11 @@ def check_linear_module_handler(rank, bias, input_shape, world_size, port):
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')}) meta_args = {"input": torch.rand(input_shape).cuda()}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[1] linear_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)
...@@ -196,13 +200,12 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port): ...@@ -196,13 +200,12 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args = {'input': torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"input": torch.rand(input_shape).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
if bias: if bias:
linear_func_node = list(graph.nodes)[3] linear_func_node = list(graph.nodes)[3]
else: else:
......
...@@ -2,6 +2,9 @@ import pytest ...@@ -2,6 +2,9 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import ( from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
MatMulHandler, MatMulHandler,
MatMulType, MatMulType,
...@@ -15,7 +18,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -15,7 +18,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize from colossalai.testing.utils import parameterize
...@@ -57,9 +59,11 @@ def test_matmul_node_handler(tensor_shapes): ...@@ -57,9 +59,11 @@ def test_matmul_node_handler(tensor_shapes):
model = MatMulModule() model = MatMulModule()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')}) meta_args = {"x1": x1.to('meta'), 'x2': x2.to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
print(graph) print(graph)
...@@ -124,7 +128,6 @@ def test_matmul_node_handler(tensor_shapes): ...@@ -124,7 +128,6 @@ def test_matmul_node_handler(tensor_shapes):
input_sharding_spec = strategy.get_sharding_spec_by_name('x1') input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2') other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul') output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')
if matmul_type == MatMulType.DOT: if matmul_type == MatMulType.DOT:
# dot product will produce a scaler # dot product will produce a scaler
# results should fulfill: # results should fulfill:
...@@ -159,6 +162,9 @@ def test_matmul_node_handler(tensor_shapes): ...@@ -159,6 +162,9 @@ def test_matmul_node_handler(tensor_shapes):
if len(other_shape) > 1: if len(other_shape) > 1:
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if len(input_shape) > 1: if len(input_shape) > 1:
if len(other_shape) == 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-1]
else:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2] assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) > 2: if len(other_shape) > 2:
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1] assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]
......
...@@ -2,10 +2,12 @@ import pytest ...@@ -2,10 +2,12 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
...@@ -13,14 +15,16 @@ from colossalai.testing.pytest_wrapper import run_on_environment_flag ...@@ -13,14 +15,16 @@ from colossalai.testing.pytest_wrapper import run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler(): def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0 # return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
......
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -18,19 +21,20 @@ class OutputModel(nn.Module): ...@@ -18,19 +21,20 @@ class OutputModel(nn.Module):
return x, y return x, y
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated']) @parameterize('output_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_output_handler(output_option): def test_output_handler(output_option):
model = OutputModel() model = OutputModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %x : torch.Tensor [#users=2] = placeholder[target=x] # %x : torch.Tensor [#users=2] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# return (x, mul) # return (x, mul)
graph = tracer.trace(model, meta_args={ meta_args = {'x': torch.rand(4, 4, 64, 64).to('meta')}
"x": torch.rand(4, 4, 64, 64).to('meta'), graph = tracer.trace(model, meta_args=meta_args)
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
......
...@@ -5,12 +5,14 @@ import torch ...@@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -88,7 +90,7 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, ...@@ -88,7 +90,7 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
input_args=[input, other], input_args=[input, other],
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvReshapeModel': if model_cls.__name__ == 'ConvReshapeModel':
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
...@@ -96,11 +98,11 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, ...@@ -96,11 +98,11 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None})
# %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {})
# return permute # return permute
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 8, 66, 66).to('meta'),
"input": torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
if model_cls.__name__ == 'LinearReshapeModel': if model_cls.__name__ == 'LinearReshapeModel':
# graph(): # graph():
...@@ -109,13 +111,14 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, ...@@ -109,13 +111,14 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
# return permute # return permute
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), 'other': torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
reshape_node = list(graph.nodes)[3] reshape_node = list(graph.nodes)[3]
......
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -17,18 +20,21 @@ class PlaceholderModel(nn.Module): ...@@ -17,18 +20,21 @@ class PlaceholderModel(nn.Module):
return input return input
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated']) @parameterize('placeholder_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_placeholder_handler(placeholder_option): def test_placeholder_handler(placeholder_option):
model = PlaceholderModel() model = PlaceholderModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# return input_1 # return input_1
graph = tracer.trace(model, meta_args={ meta_args = {
"input": torch.rand(4, 4, 64, 64).to('meta'), "input": torch.rand(4, 4, 64, 64).to('meta'),
}) }
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
......
from functools import partial
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.options import ShardOption from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
class LinearModel(nn.Module): class LinearModel(nn.Module):
...@@ -30,13 +28,11 @@ def check_shard_option(shard_option): ...@@ -30,13 +28,11 @@ def check_shard_option(shard_option):
mesh_shape = (2, 2) mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args = {'input': torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"input": torch.rand(4, 4, 4, 16).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_func_node = list(graph.nodes)[2] linear_func_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_func_node) strategies_vector = StrategiesVector(linear_func_node)
......
...@@ -6,11 +6,13 @@ import torch.multiprocessing as mp ...@@ -6,11 +6,13 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -54,7 +56,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port): ...@@ -54,7 +56,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
input_args=[input, other], input_args=[input, other],
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
...@@ -62,13 +64,14 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port): ...@@ -62,13 +64,14 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
# return split # return split
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), 'other': torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3] split_node = list(graph.nodes)[3]
......
...@@ -5,12 +5,14 @@ import torch ...@@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -76,7 +78,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port ...@@ -76,7 +78,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
input_args=[input, other], input_args=[input, other],
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvSplitModel': if model_cls.__name__ == 'ConvSplitModel':
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
...@@ -84,11 +86,11 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port ...@@ -84,11 +86,11 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {}) # %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {})
# return split # return split
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 8, 66, 66).to('meta'),
"input": torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
if model_cls.__name__ == 'LinearSplitModel': if model_cls.__name__ == 'LinearSplitModel':
# graph(): # graph():
...@@ -97,13 +99,14 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port ...@@ -97,13 +99,14 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
# return split # return split
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), 'other': torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3] split_node = list(graph.nodes)[3]
......
...@@ -5,12 +5,13 @@ import torch ...@@ -5,12 +5,13 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -58,7 +59,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port): ...@@ -58,7 +59,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
...@@ -66,12 +67,13 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port): ...@@ -66,12 +67,13 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {}) # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {})
# return sum_1 # return sum_1
graph = tracer.trace(model, meta_args = {
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'), "input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), "other": torch.rand(64, 32).to('meta'),
}) }
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
sum_node = list(graph.nodes)[3] sum_node = list(graph.nodes)[3]
...@@ -116,107 +118,107 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port): ...@@ -116,107 +118,107 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
# check strategy name # check strategy name
if sum_dims == (0, 2) and keepdim == False: if sum_dims == (0, 2) and keepdim == False:
assert '[R, R, R, S1] -> [R, S1]_0' in strategy_name_list assert '[R, R, R, R] -> [R, R]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [S0, S1]_1' in strategy_name_list assert '[R, S01, R, R] -> [S01, R]_1' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_2' in strategy_name_list assert '[R, R, R, R] -> [R, R]_2' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_3' in strategy_name_list assert '[R, R, R, R] -> [R, R]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [S1, S0]_4' in strategy_name_list assert '[R, R, R, S01] -> [R, S01]_4' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_5' in strategy_name_list assert '[R, R, R, S1] -> [R, S1]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_6' in strategy_name_list assert '[R, R, R, S0] -> [R, S0]_6' in strategy_name_list
assert '[R, S0, R, R] -> [S0, R]_7' in strategy_name_list assert '[R, R, R, R] -> [R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_9' in strategy_name_list assert '[R, R, R, S0] -> [R, S0]_9' in strategy_name_list
assert '[R, S1, R, R] -> [S1, R]_10' in strategy_name_list assert '[R, R, R, S1] -> [R, S1]_10' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_11' in strategy_name_list assert '[R, R, R, S1] -> [R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_12' in strategy_name_list assert '[R, S0, R, S1] -> [S0, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_13' in strategy_name_list assert '[R, R, R, S1] -> [R, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_14' in strategy_name_list assert '[R, R, R, S0] -> [R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_15' in strategy_name_list assert '[R, S1, R, S0] -> [S1, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1]_17' in strategy_name_list assert '[R, R, R, R] -> [R, R]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_18' in strategy_name_list assert '[R, S0, R, R] -> [S0, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [S01, R]_19' in strategy_name_list assert '[R, R, R, R] -> [R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_21' in strategy_name_list assert '[R, S1, R, R] -> [S1, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, S01]_22' in strategy_name_list assert '[R, R, R, R] -> [R, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list
if sum_dims == (0, 2) and keepdim == True: if sum_dims == (0, 2) and keepdim == True:
assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_2' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_2' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_5' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_11' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list
if sum_dims == 1 and keepdim == False: if sum_dims == 1 and keepdim == False:
assert '[S0, R, R, S1] -> [S0, R, S1]_0' in strategy_name_list assert '[S01, R, R, R] -> [S01, R, R]_0' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_1' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, S1]_2' in strategy_name_list assert '[R, R, S01, R] -> [R, S01, R]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, S0]_3' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_4' in strategy_name_list assert '[R, R, R, S01] -> [R, R, S01]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, S0]_5' in strategy_name_list assert '[R, R, R, S1] -> [R, R, S1]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R]_6' in strategy_name_list assert '[R, R, R, S0] -> [R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R]_8' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R]_9' in strategy_name_list assert '[R, R, R, S0] -> [R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_10' in strategy_name_list assert '[R, R, R, S1] -> [R, R, S1]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R]_11' in strategy_name_list assert '[S0, R, R, S1] -> [S0, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_13' in strategy_name_list assert '[R, R, S0, S1] -> [R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_14' in strategy_name_list assert '[S1, R, R, S0] -> [S1, R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_15' in strategy_name_list assert '[R, R, R, S0] -> [R, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_16' in strategy_name_list assert '[R, R, S1, S0] -> [R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_17' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R]_18' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_19' in strategy_name_list assert '[R, R, S0, R] -> [R, S0, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R]_20' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, S01]_22' in strategy_name_list assert '[R, R, S1, R] -> [R, S1, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list
if sum_dims == 1 and keepdim == True: if sum_dims == 1 and keepdim == True:
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
...@@ -22,7 +24,7 @@ class TensorConstructorModel(nn.Module): ...@@ -22,7 +24,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
def test_where_handler(): def test_where_handler():
model = TensorConstructorModel() model = TensorConstructorModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %x : torch.Tensor [#users=2] = placeholder[target=x] # %x : torch.Tensor [#users=2] = placeholder[target=x]
# %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {}) # %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {})
...@@ -30,10 +32,10 @@ def test_where_handler(): ...@@ -30,10 +32,10 @@ def test_where_handler():
# %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {}) # %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {})
# return add # return add
graph = tracer.trace(model, meta_args={ meta_args = {'x': torch.rand(10).to('meta')}
"x": torch.rand(10).to('meta'), graph = tracer.trace(model, meta_args=meta_args)
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
...@@ -25,19 +26,20 @@ class ReLuModel(nn.Module): ...@@ -25,19 +26,20 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
def test_elementwise_handler(): def test_elementwise_handler():
model = ReLuModel() model = ReLuModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other] # %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {}) # %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {})
# return act # return act
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(4, 4, 64, 64).to('meta'),
"input": torch.rand(4, 4, 64, 64).to('meta'), 'other': torch.rand(16, 4, 3, 3).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
...@@ -69,13 +71,13 @@ def test_elementwise_handler(): ...@@ -69,13 +71,13 @@ def test_elementwise_handler():
assert mapping['input'].name == "conv2d" assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].name == "act" assert mapping['output'].name == "act"
assert mapping['output'].data.is_meta assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 4, 62, 62]) assert mapping['output'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].type == OperationDataType.OUTPUT assert mapping['output'].type == OperationDataType.OUTPUT
# getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node.
......
...@@ -5,12 +5,14 @@ import torch ...@@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -74,7 +76,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): ...@@ -74,7 +76,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
input_args=[input, other], input_args=[input, other],
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvViewModel': if model_cls.__name__ == 'ConvViewModel':
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
...@@ -82,11 +84,8 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): ...@@ -82,11 +84,8 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view # return view
graph = tracer.trace(model, meta_args = {'input': torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
if model_cls.__name__ == 'LinearViewModel': if model_cls.__name__ == 'LinearViewModel':
# graph(): # graph():
...@@ -95,13 +94,14 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): ...@@ -95,13 +94,14 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
# return view # return view
graph = tracer.trace(model, meta_args = {
meta_args={ 'input': torch.rand(8, 16, 64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), 'other': torch.rand(64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
previous_mod_node = list(graph.nodes)[2] previous_mod_node = list(graph.nodes)[2]
view_node = list(graph.nodes)[3] view_node = list(graph.nodes)[3]
......
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \ from colossalai._analyzer.fx.graph_module import ColoGraphModule
WhereHandler from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class ConvModel(nn.Module): class ConvModel(nn.Module):
...@@ -19,22 +20,24 @@ class ConvModel(nn.Module): ...@@ -19,22 +20,24 @@ class ConvModel(nn.Module):
return output return output
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_where_handler(): def test_where_handler():
model = ConvModel() model = ConvModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %condition : torch.Tensor [#users=1] = placeholder[target=condition] # %condition : torch.Tensor [#users=1] = placeholder[target=condition]
# %x : torch.Tensor [#users=1] = placeholder[target=x] # %x : torch.Tensor [#users=1] = placeholder[target=x]
# %y : torch.Tensor [#users=1] = placeholder[target=y] # %y : torch.Tensor [#users=1] = placeholder[target=y]
# %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
# return where # return where
graph = tracer.trace(model, meta_args = {
meta_args={ 'condition': torch.rand(4, 4, 64, 64).to('meta'),
"condition": torch.rand(4, 4, 64, 64).to('meta'), 'x': torch.rand(4, 1, 64, 64).to('meta'),
"x": torch.rand(4, 1, 64, 64).to('meta'), 'y': torch.rand(1, 4, 64, 64).to('meta')
"y": torch.rand(1, 4, 64, 64).to('meta') }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
......
...@@ -4,6 +4,9 @@ from typing import Dict, List ...@@ -4,6 +4,9 @@ from typing import Dict, List
import torch import torch
from torch.fx import GraphModule from torch.fx import GraphModule
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.options import SolverOptions
...@@ -11,7 +14,6 @@ from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor ...@@ -11,7 +14,6 @@ from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close from colossalai.testing.comparison import assert_close
...@@ -79,14 +81,16 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, ...@@ -79,14 +81,16 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
grad_to_shard_dict) grad_to_shard_dict)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
input_sample = {} input_sample = {}
for input_arg, meta_arg_name in zip(input_args, meta_arg_names): for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to('meta')
for meta_kwarg_name, input_kwarg in input_kwargs.items(): for meta_kwarg_name, input_kwarg in input_kwargs.items():
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to('meta')
graph = tracer.trace(root=model_to_shard, meta_args=input_sample) graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
shape_prop_pass(gm, *input_sample.values())
solver_options = SolverOptions() solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost() strategies_constructor.build_strategies_and_cost()
......
import pytest import pytest
import torch import torch
import transformers import transformers
from topo_utils import split_model_and_get_DAG, check_topo, MLP from topo_utils import MLP, check_topo, split_model_and_get_DAG
BATCH_SIZE = 1 BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGHT = 16
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_opt(): def test_opt():
MODEL_LIST = [ MODEL_LIST = [
MLP, MLP,
...@@ -13,7 +15,10 @@ def test_opt(): ...@@ -13,7 +15,10 @@ def test_opt():
] ]
CONFIGS = [ CONFIGS = [
{'dim': 10, 'layers': 12}, {
'dim': 10,
'layers': 12
},
transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
] ]
...@@ -39,5 +44,6 @@ def test_opt(): ...@@ -39,5 +44,6 @@ def test_opt():
# print(f'{top_mod=}\n----\n{topo=}') # print(f'{top_mod=}\n----\n{topo=}')
check_topo(top_mod, topo) check_topo(top_mod, topo)
if __name__ == '__main__': if __name__ == '__main__':
test_opt() test_opt()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment