Unverified Commit fee2af86 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] adapt autoparallel with new analyzer (#3261)

* [autoparallel] adapt autoparallel with new analyzer

* fix all node handler tests

* polish

* polish
parent e78a1e94
...@@ -446,10 +446,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'): ...@@ -446,10 +446,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
@register_meta(aten.embedding_dense_backward.default) @register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq): scale_grad_by_freq):
return new((num_weights, grad_output.size(-1)), return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout)
# ============================== Dropout =========================================== # ============================== Dropout ===========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
......
...@@ -51,7 +51,10 @@ def _normalize_tuple(x): ...@@ -51,7 +51,10 @@ def _normalize_tuple(x):
def _current_device(module): def _current_device(module):
return next(module.parameters()).device try:
return next(module.parameters()).device
except StopIteration:
return torch.device('cpu')
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
...@@ -120,15 +123,18 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -120,15 +123,18 @@ class ShapeProp(torch.fx.Interpreter):
return t.to('meta') return t.to('meta')
if isinstance(elem, MetaTensor): if isinstance(elem, MetaTensor):
if getattr(self, '_is_param', False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor) return _convert_meta(elem._tensor)
elif isinstance(elem, torch.Tensor): elif isinstance(elem, torch.Tensor):
if isinstance(elem, torch.nn.Parameter):
return torch.nn.Parameter(_convert_meta(elem))
return _convert_meta(elem) return _convert_meta(elem)
else: else:
return elem return elem
# unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter) is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
n_info = MetaInfo(n) n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r) n_info.outputs = _normalize_tuple(r)
...@@ -149,7 +155,11 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -149,7 +155,11 @@ class ShapeProp(torch.fx.Interpreter):
n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \ n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
tuple(v for v in kwargs.values() if is_pure_tensor(v)) tuple(v for v in kwargs.values() if is_pure_tensor(v))
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r)) # align with SPMD # align with SPMD
if isinstance(r, (tuple, list)):
n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
else:
n._meta_data = unwrap_fn(r)
n_info.global_ctx = self.global_hook.ctx n_info.global_ctx = self.global_hook.ctx
n_info.curr_ctx = self.global_hook.ctx.copy() n_info.curr_ctx = self.global_hook.ctx.copy()
...@@ -175,10 +185,48 @@ class ShapeProp(torch.fx.Interpreter): ...@@ -175,10 +185,48 @@ class ShapeProp(torch.fx.Interpreter):
Return Return
Any: The value returned by the function invocation Any: The value returned by the function invocation
""" """
convert_to_param = False
if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
convert_to_param = True
if target in self._custom_dispatch_func: if target in self._custom_dispatch_func:
return self._custom_dispatch_func[target](*args, **kwargs) res = self._custom_dispatch_func[target](*args, **kwargs)
else:
res = super().call_function(target, args, kwargs)
if convert_to_param:
return torch.nn.Parameter(res)
else:
return res
def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Return
Any: The value returned by the method invocation
"""
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
target_method = getattr(self_obj.__class__, target)
convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
args[0], torch.nn.parameter.Parameter):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)
res = getattr(self_obj, target)(*args_tail, **kwargs)
if convert_to_parameter:
return torch.nn.Parameter(res)
else: else:
return super().call_function(target, args, kwargs) return res
def propagate(self, *args, device=None): def propagate(self, *args, device=None):
""" """
......
...@@ -21,111 +21,69 @@ def linear_impl(input, weight, bias=None): ...@@ -21,111 +21,69 @@ def linear_impl(input, weight, bias=None):
@register_tracer_impl(F.conv1d, name='_bias_addition_impl') @register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): def conv1d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
if bias is None: if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv1d(input, weight, **kwargs)
else: else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( new_kwargs = kwargs
(-1, 1)) new_kwargs['bias'] = None
return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv2d, name='_bias_addition_impl') @register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1): def conv2d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
if bias is None: if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv2d(input, weight, **kwargs)
else: else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( new_kwargs = kwargs
(-1, 1, 1)) new_kwargs['bias'] = None
return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv3d, name='_bias_addition_impl') @register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): def conv3d_impl(input, weight, **kwargs):
bias = getattr(kwargs, 'bias', None)
if bias is None: if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) return F.conv3d(input, weight, **kwargs)
else: else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( new_kwargs = kwargs
(-1, 1, 1, 1)) new_kwargs['bias'] = None
return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl') @register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input, def conv_transpose1d_impl(input, weight, **kwargs):
weight, bias = getattr(kwargs, 'bias', None)
bias=None,
stride=_single(1),
padding=_single(0),
output_padding=_single(0),
groups=1,
dilation=_single(1)):
if bias is None: if bias is None:
return F.conv_transpose1d(input, return F.conv_transpose1d(input, weight, **kwargs)
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else: else:
return F.conv_transpose1d(input, new_kwargs = kwargs
weight, new_kwargs['bias'] = None
stride=stride, return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl') @register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input, def conv_transpose2d_impl(input, weight, **kwargs):
weight, bias = getattr(kwargs, 'bias', None)
bias=None,
stride=_pair(1),
padding=_pair(0),
output_padding=_pair(0),
groups=1,
dilation=_pair(1)):
if bias is None: if bias is None:
return F.conv_transpose2d(input, return F.conv_transpose2d(input, weight, **kwargs)
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else: else:
return F.conv_transpose2d(input, new_kwargs = kwargs
weight, new_kwargs['bias'] = None
stride=stride, return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl') @register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input, def conv_transpose3d_impl(input, weight, **kwargs):
weight, bias = getattr(kwargs, 'bias', None)
bias=None,
stride=_triple(1),
padding=_triple(0),
output_padding=_triple(0),
groups=1,
dilation=_triple(1)):
if bias is None: if bias is None:
return F.conv_transpose3d(input, return F.conv_transpose3d(input, weight, **kwargs)
weight,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation)
else: else:
return F.conv_transpose3d(input, new_kwargs = kwargs
weight, new_kwargs['bias'] = None
stride=stride, return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation) + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl') @register_tracer_impl(torch.addmm, name='_bias_addition_impl')
......
...@@ -70,14 +70,28 @@ class MetaInfo: ...@@ -70,14 +70,28 @@ class MetaInfo:
if self._strategy is not None and self._target is not None: if self._strategy is not None and self._target is not None:
self.compute_metainfo() self.compute_metainfo()
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor: def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
""" """
Compute sharded opdata based on the given data and sharding spec. Compute sharded opdata based on the given data and sharding spec.
""" """
return OperationData(name=operation_data.name,
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), if isinstance(sharding_spec, ShardingSpec):
type=operation_data.type, op_data = OperationData(name=operation_data.name,
logical_shape=operation_data.logical_shape) data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
type=operation_data.type,
logical_shape=operation_data.logical_shape)
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
assert len(data) == len(sharding_spec), f"Length of data and sharding spec should be the same."
sharded_data = []
for d, s in zip(data, sharding_spec):
sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta"))
op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type)
else:
raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.")
return op_data
def compute_metainfo(self): def compute_metainfo(self):
""" """
......
...@@ -387,12 +387,13 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -387,12 +387,13 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# This stream is created for overlaping the communication and computation. # This stream is created for overlaping the communication and computation.
reduction_stream = torch.cuda.Stream() reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param): def _add_hook_for_grad_communication(node, param, name=None):
comm_actions = node.best_strategy.communication_actions comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action): def _filter_param_to_hook(node, op_data, comm_action, name):
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
return True return True
if node.op == 'get_attr' and isinstance( if node.op == 'get_attr' and isinstance(
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
...@@ -402,7 +403,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -402,7 +403,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
for operation_data, comm_action in comm_actions.items(): for operation_data, comm_action in comm_actions.items():
comm_spec_to_use = comm_action.comm_spec comm_spec_to_use = comm_action.comm_spec
# register hook to the parameters # register hook to the parameters
if _filter_param_to_hook(node, operation_data, comm_action): if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap): def wrapper(param, comm_spec, stream, overlap):
...@@ -442,7 +443,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -442,7 +443,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
param = _shard_param(param, target_sharding_spec) param = _shard_param(param, target_sharding_spec)
setattr(target_module, name, param) setattr(target_module, name, param)
_add_hook_for_grad_communication(node, param) _add_hook_for_grad_communication(node, param, name)
sharded_buffer_dict = {} sharded_buffer_dict = {}
# apply the sharding spec of buffers # apply the sharding spec of buffers
......
...@@ -81,7 +81,10 @@ class AddBMMFunctionHandler(NodeHandler): ...@@ -81,7 +81,10 @@ class AddBMMFunctionHandler(NodeHandler):
def get_strategy_generator(self) -> List[StrategyGenerator]: def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)) generator = BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)
# addbmm will shrink the first batch dim
generator.squeeze_batch_dim = True
generators.append(generator)
return generators return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
......
...@@ -776,10 +776,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): ...@@ -776,10 +776,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
bias_op_data = self.op_data['bias'] bias_op_data = self.op_data['bias']
assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2 assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
if self.op_data['output'].data.dim() == 2:
# addbmm will shrink the first batch dim
self.squeeze_batch_dim = True
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
self.op_data['output'].data.shape) self.op_data['output'].data.shape)
......
...@@ -386,7 +386,7 @@ def meta_local_scalar_dense(self: torch.Tensor): ...@@ -386,7 +386,7 @@ def meta_local_scalar_dense(self: torch.Tensor):
@register_meta(aten.where.self) @register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
result_type = torch.result_type(self, other) result_type = torch.result_type(self, other)
return torch.empty_like(self, dtype=result_type) return torch.empty_like(condition + self + other, dtype=result_type)
@register_meta(aten.index.Tensor) @register_meta(aten.index.Tensor)
......
from faulthandler import disable
from functools import partial from functools import partial
from xml.dom import WrongDocumentErr
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from typing_extensions import Self
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType, OperationDataType,
ShardingStrategy, ShardingStrategy,
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use
...@@ -96,7 +94,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port) ...@@ -96,7 +94,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
node_type='bias_module') node_type='bias_module')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %m1 : torch.Tensor [#users=1] = placeholder[target=m1] # %m1 : torch.Tensor [#users=1] = placeholder[target=m1]
...@@ -109,6 +107,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port) ...@@ -109,6 +107,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
# return add # return add
graph = tracer.trace(model, meta_args=meta_args_for_tracer) graph = tracer.trace(model, meta_args=meta_args_for_tracer)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args_for_tracer.values())
# [input_1, m1, m2, addmm, output] # [input_1, m1, m2, addmm, output]
node_list = list(graph.nodes) node_list = list(graph.nodes)
linear_node = node_list[4] linear_node = node_list[4]
......
...@@ -5,10 +5,12 @@ import torch ...@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -38,13 +40,15 @@ def check_bn_module_handler(rank, world_size, port): ...@@ -38,13 +40,15 @@ def check_bn_module_handler(rank, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=[input], input_args=[input],
meta_arg_names=['input']) meta_arg_names=['input'])
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0 # return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')}) meta_args = {"input": torch.rand(4, 16, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
bn_mod_node = list(graph.nodes)[1] bn_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(bn_mod_node) strategies_vector = StrategiesVector(bn_mod_node)
......
from faulthandler import disable
from functools import partial from functools import partial
from xml.dom import WrongDocumentErr
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing_extensions import Self
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
...@@ -17,12 +17,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -17,12 +17,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -66,7 +64,7 @@ def check_linear_module_handler(rank, world_size, port): ...@@ -66,7 +64,7 @@ def check_linear_module_handler(rank, world_size, port):
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
node_type='bias_module') node_type='bias_module')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x] # %x : torch.Tensor [#users=1] = placeholder[target=x]
# %weight : [#users=1] = get_attr[target=weight] # %weight : [#users=1] = get_attr[target=weight]
...@@ -74,8 +72,10 @@ def check_linear_module_handler(rank, world_size, port): ...@@ -74,8 +72,10 @@ def check_linear_module_handler(rank, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {}) # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {})
# return add # return add
graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[3] linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)
......
from faulthandler import disable
from functools import partial from functools import partial
from xml.dom import WrongDocumentErr
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from typing_extensions import Self
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
...@@ -16,12 +16,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -16,12 +16,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -62,9 +60,11 @@ def check_linear_module_handler(rank, bias, world_size, port): ...@@ -62,9 +60,11 @@ def check_linear_module_handler(rank, bias, world_size, port):
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
node_type='bias_module') node_type='bias_module')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')}) meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[3] linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)
......
...@@ -5,10 +5,12 @@ import torch ...@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -52,10 +54,11 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size ...@@ -52,10 +54,11 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')} meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
op_node = list(graph.nodes)[2] op_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(op_node) strategies_vector = StrategiesVector(op_node)
...@@ -172,12 +175,11 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo ...@@ -172,12 +175,11 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 4).to('meta')} meta_args = {'x1': torch.rand(4, 4).to('meta')}
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
print(graph)
# assert False
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
if model_cls == BEOpModelWithNodeConst: if model_cls == BEOpModelWithNodeConst:
op_node = list(graph.nodes)[2] op_node = list(graph.nodes)[2]
......
...@@ -5,10 +5,12 @@ import torch ...@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -52,13 +54,11 @@ def check_2d_device_mesh(rank, module, world_size, port): ...@@ -52,13 +54,11 @@ def check_2d_device_mesh(rank, module, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"x1": torch.rand(4, 8, 16).to('meta'),
'x2': torch.rand(4, 16, 8).to('meta')
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[2] linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)
...@@ -147,13 +147,11 @@ def check_1d_device_mesh(rank, module, world_size, port): ...@@ -147,13 +147,11 @@ def check_1d_device_mesh(rank, module, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')}
meta_args={ graph = tracer.trace(model, meta_args=meta_args)
"x1": torch.rand(4, 8, 16).to('meta'),
'x2': torch.rand(4, 16, 8).to('meta')
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[2] linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node) strategies_vector = StrategiesVector(linear_mod_node)
...@@ -205,6 +203,7 @@ def check_1d_device_mesh(rank, module, world_size, port): ...@@ -205,6 +203,7 @@ def check_1d_device_mesh(rank, module, world_size, port):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) @parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_bmm_handler(module): def test_bmm_handler(module):
......
...@@ -5,10 +5,12 @@ import torch ...@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -41,9 +43,11 @@ def check_conv_module_handler(rank, bias, world_size, port): ...@@ -41,9 +43,11 @@ def check_conv_module_handler(rank, bias, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=[input], input_args=[input],
meta_arg_names=['input']) meta_arg_names=['input'])
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')}) meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
conv_mod_node = list(graph.nodes)[1] conv_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(conv_mod_node) strategies_vector = StrategiesVector(conv_mod_node)
...@@ -178,7 +182,7 @@ def check_conv_function_handler(rank, bias, world_size, port): ...@@ -178,7 +182,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
input_kwargs=input_kwargs) input_kwargs=input_kwargs)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others] # %others : torch.Tensor [#users=1] = placeholder[target=others]
...@@ -189,6 +193,7 @@ def check_conv_function_handler(rank, bias, world_size, port): ...@@ -189,6 +193,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
meta_args['bias'] = torch.rand(16).to('meta') meta_args['bias'] = torch.rand(16).to('meta')
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
if bias: if bias:
conv_mod_node = list(graph.nodes)[3] conv_mod_node = list(graph.nodes)[3]
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
...@@ -23,19 +25,20 @@ class ReshapeModel(nn.Module): ...@@ -23,19 +25,20 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
def test_reshape_handler(): def test_reshape_handler():
model = ReshapeModel() model = ReshapeModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other] # %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view # return view
graph = tracer.trace(model, meta_args = {
meta_args={ "input": torch.rand(4, 4, 64, 64).to('meta'),
"input": torch.rand(4, 4, 64, 64).to('meta'), "other": torch.rand(16, 4, 3, 3).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'), }
}) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
...@@ -67,13 +70,13 @@ def test_reshape_handler(): ...@@ -67,13 +70,13 @@ def test_reshape_handler():
assert mapping['input'].name == "conv2d" assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62]) assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['input'].type == OperationDataType.ARG assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62]) assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].name == "view" assert mapping['output'].name == "view"
assert mapping['output'].data.is_meta assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([2, 30752]) assert mapping['output'].data.shape == torch.Size([2, 123008])
assert mapping['output'].type == OperationDataType.OUTPUT assert mapping['output'].type == OperationDataType.OUTPUT
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
......
...@@ -5,13 +5,15 @@ import torch ...@@ -5,13 +5,15 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import ( from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import (
EmbeddingFunctionHandler, EmbeddingFunctionHandler,
EmbeddingModuleHandler, EmbeddingModuleHandler,
) )
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
...@@ -60,9 +62,11 @@ def check_embedding_module_handler(rank, world_size, port): ...@@ -60,9 +62,11 @@ def check_embedding_module_handler(rank, world_size, port):
input_args=[input], input_args=[input],
meta_arg_names=['input']) meta_arg_names=['input'])
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 16).to('meta')}) meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
embedding_node = list(graph.nodes)[1] embedding_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(embedding_node) strategies_vector = StrategiesVector(embedding_node)
...@@ -171,18 +175,19 @@ def check_embedding_function_handler(rank, world_size, port): ...@@ -171,18 +175,19 @@ def check_embedding_function_handler(rank, world_size, port):
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names, meta_arg_names=meta_arg_names,
input_kwargs=input_kwargs) input_kwargs=input_kwargs)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others] # %others : torch.Tensor [#users=1] = placeholder[target=others]
# %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) # %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False})
# return embedding # return embedding
meta_args = { meta_args = {
"input": torch.rand(4, 16, 16).to('meta'), "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta'),
"others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta') "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta')
} }
graph = tracer.trace(model, meta_args=meta_args) graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
embedding_node = list(graph.nodes)[2] embedding_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(embedding_node) strategies_vector = StrategiesVector(embedding_node)
......
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
class GetattrModel(nn.Module): class GetattrModel(nn.Module):
...@@ -18,15 +21,18 @@ class GetattrModel(nn.Module): ...@@ -18,15 +21,18 @@ class GetattrModel(nn.Module):
return weight return weight
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_getattr_handler(): def test_getattr_handler():
model = GetattrModel() model = GetattrModel()
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=0] = placeholder[target=input] # %input_1 : torch.Tensor [#users=0] = placeholder[target=input]
# %conv_weight : [#users=1] = get_attr[target=conv.weight] # %conv_weight : [#users=1] = get_attr[target=conv.weight]
# return conv_weight # return conv_weight
graph = tracer.trace(model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')}) meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
......
...@@ -5,13 +5,15 @@ import torch ...@@ -5,13 +5,15 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
...@@ -58,15 +60,15 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): ...@@ -58,15 +60,15 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
meta_arg_names=['input', 'other'], meta_arg_names=['input', 'other'],
node_type='following') node_type='following')
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
meta_args = {
graph = tracer.trace(model, "input": torch.rand(8, 16, 64, 32).to('meta'),
meta_args={ "other": torch.rand(64, 32).to('meta'),
"input": torch.rand(8, 16, 64, 32).to('meta'), }
"other": torch.rand(64, 32).to('meta'), graph = tracer.trace(model, meta_args=meta_args)
})
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *list(meta_args.values()))
linear_mod_node = list(graph.nodes)[2] linear_mod_node = list(graph.nodes)[2]
getitem_mod_node = list(graph.nodes)[3] getitem_mod_node = list(graph.nodes)[3]
getitem_strategies_vector = StrategiesVector(getitem_mod_node) getitem_strategies_vector = StrategiesVector(getitem_mod_node)
...@@ -129,10 +131,12 @@ def test_getitem_from_tuple_handler(): ...@@ -129,10 +131,12 @@ def test_getitem_from_tuple_handler():
# %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0}) # %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0})
# %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {}) # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
# return getitem # return getitem
graph = tracer.trace(model, meta_args={ meta_args = {
"input": torch.rand(4, 4, 64, 64).to('meta'), "input": torch.rand(4, 4, 64, 64).to('meta'),
}) }
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
......
...@@ -5,10 +5,12 @@ import torch ...@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
...@@ -40,13 +42,15 @@ def check_ln_module_handler(rank, world_size, port): ...@@ -40,13 +42,15 @@ def check_ln_module_handler(rank, world_size, port):
strategy_number=strategy_number, strategy_number=strategy_number,
input_args=input_args, input_args=input_args,
meta_arg_names=meta_arg_names) meta_arg_names=meta_arg_names)
tracer = ColoTracer() tracer = ColoTracer(bias_addition_split=True)
# graph(): # graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0 # return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')}) meta_args = {"input": torch.rand(4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
ln_mod_node = list(graph.nodes)[1] ln_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(ln_mod_node) strategies_vector = StrategiesVector(ln_mod_node)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment