Unverified Commit cb2c6a24 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] refactor runtime pass (#2644)

* [autoparallel] refactor runtime pass

* add unit test

* polish
parent 89f8975f
......@@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [
torch.nn.ReLU,
torch.nn.Softmax,
]
# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args.
# This list could be extended if any other method has the same
# argument style as view and reshape.
SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
class TestModule(torch.nn.Module):
def forward(self, x):
x = x.view(4, 4, 2)
return x
def insert_narrow(gm, x_node):
graph = gm.graph
with graph.inserting_after(x_node):
shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
view_node = list(x_node.users.keys())[0]
new_args = list(view_node.args)
new_args[0] = shard_node
view_node.args = tuple(new_args)
return gm
def test_node_args_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
meta_args = {'x': torch.rand(4, 8).to('meta')}
input = torch.rand(4, 8)
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
x_node = list(graph.nodes)[0]
view_node = list(graph.nodes)[1]
sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
setattr(x_node, 'sharding_spec', sharding_spec)
setattr(view_node, 'sharding_spec', sharding_spec)
gm = ColoGraphModule(model, graph)
gm = node_args_converting_pass(gm, device_mesh)
gm = insert_narrow(gm, x_node)
gm.recompile()
output = gm(input)
assert output.shape == torch.Size([2, 4, 2])
if __name__ == '__main__':
test_node_args_converting_pass()
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
class TestModule(torch.nn.Module):
def forward(self, x):
size = x.size()
return size
def insert_narrow(gm, x_node):
graph = gm.graph
with graph.inserting_after(x_node):
shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
size_node = list(x_node.users.keys())[0]
size_node.args = (shard_node,)
return gm
def recover_narrow(gm, narrow_node):
graph = gm.graph
size_node = list(graph.nodes)[2]
x_node = narrow_node.args[0]
size_node.args = (x_node,)
graph.erase_node(narrow_node)
return gm
def test_size_value_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
meta_args = {'x': torch.rand(4, 8).to('meta')}
input = torch.rand(4, 8)
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
x_node = list(graph.nodes)[0]
x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
setattr(x_node, 'sharding_spec', x_sharding_spec)
gm = ColoGraphModule(model, graph)
gm = insert_narrow(gm, x_node)
gm.recompile()
size = gm(input)
assert size == torch.Size([2, 8])
narrow_node = list(gm.graph.nodes)[1]
gm = recover_narrow(gm, narrow_node)
gm = size_value_converting_pass(gm, device_mesh)
gm = insert_narrow(gm, x_node)
gm.recompile()
size = gm(input)
assert size == torch.Size([4, 8])
if __name__ == '__main__':
test_size_value_converting_pass()
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment