Unverified commit cb2c6a24 authored by YuliangLiu0306, committed by GitHub

[autoparallel] refactor runtime pass (#2644)

* [autoparallel] refactor runtime pass

* add unit test

* polish
parent 89f8975f
...
@@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [
    torch.nn.ReLU,
    torch.nn.Softmax,
]
# SHAPE_ARGUMENT_OPS contains nodes with (input, *shape) style args.
# This list could be extended if any other method has the same
# argument style as view and reshape.
SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]
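For context, the (input, *shape) calling style these three ops share, shown as a small standalone snippet (illustrative only, not part of this commit):

# Illustrative only: each call below traces to an FX node whose args are
# the input tensor followed by the target shape.
import torch

x = torch.rand(4, 8)
y1 = x.view(4, 4, 2)                # torch.Tensor.view    -> args (x, 4, 4, 2)
y2 = x.reshape(4, 4, 2)             # torch.Tensor.reshape -> args (x, 4, 4, 2)
y3 = torch.reshape(x, (4, 4, 2))    # torch.reshape        -> args (x, (4, 4, 2))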
import torch

from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec

class TestModule(torch.nn.Module):

    def forward(self, x):
        x = x.view(4, 4, 2)
        return x

def insert_narrow(gm, x_node):
    # Manually insert a `narrow` node after the input to emulate the dim-0
    # shard a single rank would hold at runtime.
    graph = gm.graph
    with graph.inserting_after(x_node):
        shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
    view_node = list(x_node.users.keys())[0]
    new_args = list(view_node.args)
    new_args[0] = shard_node
    view_node.args = tuple(new_args)
    return gm

def test_node_args_converting_pass():
    model = TestModule()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    meta_args = {'x': torch.rand(4, 8).to('meta')}
    input = torch.rand(4, 8)

    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)
    x_node = list(graph.nodes)[0]
    view_node = list(graph.nodes)[1]
    sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
    setattr(x_node, 'sharding_spec', sharding_spec)
    setattr(view_node, 'sharding_spec', sharding_spec)

    gm = ColoGraphModule(model, graph)
    gm = node_args_converting_pass(gm, device_mesh)
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    output = gm(input)
    # The pass rewrites the view args from the global shape (4, 4, 2) to the
    # dim-0 sharded shape (2, 4, 2), so the narrowed input still matches.
    assert output.shape == torch.Size([2, 4, 2])


if __name__ == '__main__':
    test_node_args_converting_pass()
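Conceptually, the pass converts shape arguments from the global shape to the local (sharded) shape. A minimal sketch of that conversion, assuming each partitioned dimension is divided by the sizes of its mesh axes; `localize_shape` is a hypothetical helper for illustration, not the ColossalAI implementation:

# Hypothetical sketch: divide each partitioned dim by the product of its
# mesh axis sizes to obtain the per-rank local shape.
def localize_shape(global_shape, dim_partition_dict, mesh_shape):
    local = list(global_shape)
    for dim, mesh_axes in dim_partition_dict.items():
        for axis in mesh_axes:
            local[dim] //= mesh_shape[axis]
    return tuple(local)

# With dim 0 sharded over mesh axis 0 of a (2, 2) mesh, the view shape
# (4, 4, 2) becomes (2, 4, 2), matching the assertion in the test above.
assert localize_shape((4, 4, 2), {0: [0]}, (2, 2)) == (2, 4, 2)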
import torch

from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec

class TestModule(torch.nn.Module):

    def forward(self, x):
        size = x.size()
        return size

def insert_narrow(gm, x_node):
    # Manually shard the input along dim 0 so that `size()` observes the
    # local shape a single rank would see.
    graph = gm.graph
    with graph.inserting_after(x_node):
        shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
    size_node = list(x_node.users.keys())[0]
    size_node.args = (shard_node,)
    return gm

def recover_narrow(gm, narrow_node):
    # Undo insert_narrow: re-point the size node at the original input and
    # erase the narrow node from the graph.
    graph = gm.graph
    size_node = list(graph.nodes)[2]
    x_node = narrow_node.args[0]
    size_node.args = (x_node,)
    graph.erase_node(narrow_node)
    return gm

def test_size_value_converting_pass():
    model = TestModule()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    meta_args = {'x': torch.rand(4, 8).to('meta')}
    input = torch.rand(4, 8)

    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)
    x_node = list(graph.nodes)[0]
    x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
    setattr(x_node, 'sharding_spec', x_sharding_spec)
    gm = ColoGraphModule(model, graph)

    # Without the pass, `size()` on the narrowed tensor returns the local shape.
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    size = gm(input)
    assert size == torch.Size([2, 8])

    # With the pass applied, `size()` is converted to report the global shape
    # even though the runtime tensor is sharded.
    narrow_node = list(gm.graph.nodes)[1]
    gm = recover_narrow(gm, narrow_node)
    gm = size_value_converting_pass(gm, device_mesh)
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    size = gm(input)
    assert size == torch.Size([4, 8])


if __name__ == '__main__':
    test_size_value_converting_pass()
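size_value_converting_pass does the inverse bookkeeping: a traced size() call should report the global shape recorded in the sharding spec rather than the local shard's. A hedged sketch of that recovery; `globalize_size` and its logic are assumptions for illustration, not the library's code:

import torch

# Hypothetical sketch: multiply each partitioned dim by the sizes of its
# mesh axes to recover the global shape from a local one.
def globalize_size(local_size, dim_partition_dict, mesh_shape):
    global_size = list(local_size)
    for dim, mesh_axes in dim_partition_dict.items():
        for axis in mesh_axes:
            global_size[dim] *= mesh_shape[axis]
    return torch.Size(global_size)

# A (2, 8) local shard on mesh axis 0 of a (2, 2) mesh reports (4, 8)
# globally, matching the second assertion in the test above.
assert globalize_size(torch.Size([2, 8]), {0: [0]}, (2, 2)) == torch.Size([4, 8])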
-from faulthandler import disable
 from functools import partial
-from xml.dom import WrongDocumentErr
 import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from typing_extensions import Self
 from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
...