Unverified Commit e859380b authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[fx] support module with bias addition (#1780)

* [autoparallel] refactor tracer to fix bias addition issue

* [fx] support module with bias addition

* create bias_addition_module

* refactor file structure

* polish code

* fix unit test
parent f3f19a5c
import torch import torch
from torch.fx import symbolic_trace from torch.fx import symbolic_trace
from torch.fx.node import Node from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module from colossalai.fx.passes.split_module import split_module
...@@ -37,6 +37,21 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): ...@@ -37,6 +37,21 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
else: else:
with mod_graph.inserting_after(node): with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split) split_node = mod_graph.create_node('call_function', pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
if node.op == 'placeholder':
continue
elif node_counter == 0:
node_counter += 1
else:
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile() gm.recompile()
return gm return gm
......
from .tracer import ColoTracer from colossalai.fx.tracer.meta_patch.patched_function.python_ops import operator_getitem
from ._meta_trace import meta_trace
from ._meta_trace import meta_trace
from .tracer import ColoTracer
from .patched_bias_addition_function import *
from .patched_bias_addition_module import *
from .bias_addition_module import *
from .conv import *
from .linear import *
import operator
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
class BiasAdditionModule(ABC):
    """
    Base class for restructuring the computation graph of a ``call_module``
    node whose module performs an implicit bias addition (e.g. Linear, Conv).

    The rebuilt graph makes weight, the bias-free functional computation and
    the bias addition explicit as separate nodes. Subclasses provide the
    module-specific pieces via :meth:`extract_kwargs_from_mod` and
    :meth:`generate`.
    """

    def __init__(self, tracer, target, args, kwargs, substitute_func):
        self.tracer = tracer
        self.target = target
        self.args = args
        self.kwargs = kwargs
        self.substitute_func = substitute_func
        self.weight_proxy = self._create_weight_proxy()
        self.bias_proxy = self._create_bias_proxy()

    def _create_weight_proxy(self):
        """
        Build a ``get_attr`` proxy whose node refers to the module's weight.

        Note: invoked automatically during module initializing; never call
        this function yourself.
        """
        return self.tracer.create_proxy('get_attr', self.target + '.weight', (), {})

    def _create_bias_proxy(self):
        """
        Build a ``get_attr`` proxy whose node refers to the module's bias.

        Note: invoked automatically during module initializing; never call
        this function yourself.
        """
        return self.tracer.create_proxy('get_attr', self.target + '.bias', (), {})

    @abstractmethod
    def extract_kwargs_from_mod(self):
        """
        Extract the kwargs needed for the non-bias functional computation.

        For example: a ``Conv2d`` module stores attributes such as
        ``padding`` or ``groups`` at construction time, so its call_module
        node has empty kwargs — but ``F.conv2d`` must receive them
        explicitly as keyword arguments.
        """
        pass

    def create_non_bias_func_proxy(self, input_proxy=None):
        """
        Create the proxy for the main (bias-free) computation node, e.g. a
        convolution or matmul performed by ``self.substitute_func``.

        Defaults to the original node's first argument as input when
        ``input_proxy`` is not given.
        """
        if input_proxy is None:
            input_proxy = self.args[0]
        func_kwargs = self.extract_kwargs_from_mod()
        return self.tracer.create_proxy('call_function', self.substitute_func,
                                        (input_proxy, self.weight_proxy), func_kwargs)

    def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy):
        """
        Create the proxy for the node that adds ``bias_proxy`` onto the
        result of the bias-free computation via ``operator.add``.
        """
        return self.tracer.create_proxy('call_function', operator.add,
                                        (non_bias_func_proxy, bias_proxy), {})

    @abstractmethod
    def generate(self):
        """
        Construct the full restructured graph for a call_module node with
        bias addition inside: a weight node, a bias node, the non-bias
        computation node, an optional bias reshape node, and the final add.

        Using a Conv2d module as an example, the origin node::

            %conv: call_module[target=conv](args = (%x,), kwargs = {})

        becomes::

            %conv_weight : [#users=1] = get_attr[target=conv.weight]
            %conv_bias : [#users=1] = get_attr[target=conv.bias]
            %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
            %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
            %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
        """
        pass
# Maps a bias-holding module class to the functional op performing the same
# computation without implicit bias; used as `substitute_func` when the
# call_module node is restructured into explicit weight/compute/add nodes.
module_to_func_dict = {
torch.nn.Linear: F.linear,
torch.nn.Conv1d: F.conv1d,
torch.nn.Conv2d: F.conv2d,
torch.nn.Conv3d: F.conv3d,
}
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Conv1d)
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):
    """
    Restructures a convolution call_module node into an explicit graph:
    a bias-free ``F.convNd`` call, a bias reshape so the bias broadcasts
    against the conv output, and an ``operator.add`` node.
    """

    def extract_kwargs_from_mod(self):
        """
        Collect the kwargs that must be forwarded to the functional conv.

        Attributes such as ``groups``/``dilation``/``stride`` are fixed at
        module construction time, so they have to be passed explicitly to
        ``F.convNd``.
        """
        root = self.tracer.root
        conv_module = root.get_submodule(self.target)
        kwarg_attributes = ['groups', 'dilation', 'stride']
        non_bias_kwargs = {}
        for attr_name in kwarg_attributes:
            if hasattr(conv_module, attr_name):
                non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
        if conv_module.padding_mode != "zeros":
            # Non-"zeros" padding modes are applied by the module itself
            # before convolving, so the functional call gets zero padding.
            # Bug fix: the original compared type(conv_module) against the
            # *string* "torch.nn.ConvNd", which can never match and left
            # `padding_element` undefined (NameError). Compare against the
            # classes themselves instead.
            if isinstance(conv_module, torch.nn.Conv1d):
                padding_element = _single(0)
            elif isinstance(conv_module, torch.nn.Conv2d):
                padding_element = _pair(0)
            elif isinstance(conv_module, torch.nn.Conv3d):
                padding_element = _triple(0)
            else:
                raise ValueError(f'unsupported conv module type {type(conv_module)}')
            non_bias_kwargs['padding'] = padding_element
        else:
            non_bias_kwargs['padding'] = getattr(conv_module, 'padding')
        return non_bias_kwargs

    def create_bias_reshape_proxy(self, dimensions):
        """
        Reshape the bias node to ``[1, -1, 1, ...]`` (channel dim second) so
        it is broadcastable against the output of the non-bias convolution.
        """
        bias_shape = [1] * dimensions
        bias_shape[1] = -1
        bias_reshape_node_kind = 'call_method'
        bias_reshape_node_target = 'view'
        bias_reshape_node_args = (self.bias_proxy, bias_shape)
        bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
                                                      bias_reshape_node_args, {})
        return bias_reshape_proxy

    def generate(self):
        """Emit conv-without-bias, bias reshape, and bias-add nodes."""
        non_bias_conv_func_proxy = self.create_non_bias_func_proxy()
        # meta_data holds the traced output tensor; its rank determines the
        # bias reshape dimensionality.
        output_dims = non_bias_conv_func_proxy.meta_data.dim()
        bias_reshape_proxy = self.create_bias_reshape_proxy(output_dims)
        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_conv_func_proxy, bias_reshape_proxy)
        return bias_addition_proxy
import torch
import torch.nn.functional as F
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):
    """
    Restructures a ``torch.nn.Linear`` call_module node into an explicit
    graph: a bias-free ``F.linear`` call followed by an ``operator.add``
    with the module's bias. No reshape is required because the bias already
    broadcasts against the linear output.
    """

    def extract_kwargs_from_mod(self):
        # F.linear needs no extra configuration beyond input and weight.
        return {}

    def generate(self):
        """Emit the non-bias linear node plus the bias-addition node."""
        matmul_proxy = self.create_non_bias_func_proxy()
        return self.create_bias_addition_proxy(matmul_proxy, self.bias_proxy)
from .registry import *
from .patched_function import * from .patched_function import *
from .patched_module import * from .patched_module import *
from .activation_function import * from .activation_function import *
from .arithmetic import * from .arithmetic import *
from .convolution import *
from .embedding import * from .embedding import *
from .normalization import * from .normalization import *
from .python_ops import *
from .torch_ops import * from .torch_ops import *
from .convolution import *
\ No newline at end of file
import torch import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.relu) @meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False): def torch_nn_func_relu(input, inplace=False):
return torch.empty(input.shape, device='meta') return torch.empty(input.shape, device='meta')
\ No newline at end of file
import torch import torch
from ..registry import meta_patched_function from ...registry import meta_patched_function
@meta_patched_function.register(torch.matmul) @meta_patched_function.register(torch.matmul)
...@@ -57,6 +57,16 @@ def torch_bmm(input, mat2, *, out=None): ...@@ -57,6 +57,16 @@ def torch_bmm(input, mat2, *, out=None):
return torch.empty(batch_size, n, p, device="meta") return torch.empty(batch_size, n, p, device="meta")
@meta_patched_function.register(torch.nn.functional.linear)
def torch_linear(input, mat2, bias=None, *, out=None):
    """
    Meta-tensor shape inference for ``torch.nn.functional.linear``.

    ``mat2`` is the weight with shape ``(out_features, in_features)``; the
    result replaces the last dimension of ``input`` with ``out_features``.
    The optional ``bias`` (added for signature compatibility with
    ``F.linear``) broadcasts over the output and does not change its shape.

    Raises ``ValueError`` if ``out`` is supplied, since in-place results are
    not supported for MetaTensor analysis.
    """
    if out is not None:
        # Fixed message: the original said "in-place abs", copied from
        # another patch.
        raise ValueError("Don't support in-place linear for MetaTensor analysis")
    output_shape = list(input.shape)
    # Weight layout is (out_features, in_features) -> last dim becomes
    # out_features.
    output_feature = list(mat2.shape)[0]
    output_shape[-1] = output_feature
    return torch.empty(*output_shape, device="meta")
@meta_patched_function.register(torch.addbmm) @meta_patched_function.register(torch.addbmm)
@meta_patched_function.register(torch.Tensor.addbmm) @meta_patched_function.register(torch.Tensor.addbmm)
def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
......
import torch
import collections import collections
from itertools import repeat
from ..registry import meta_patched_function
import math import math
from itertools import repeat
import torch
from ...registry import meta_patched_function
def _ntuple(n, name="parse"): def _ntuple(n, name="parse"):
......
import torch import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.embedding) @meta_patched_function.register(torch.nn.functional.embedding)
...@@ -10,4 +11,4 @@ def torch_nn_functional_embedding(input, ...@@ -10,4 +11,4 @@ def torch_nn_functional_embedding(input,
norm_type=2.0, norm_type=2.0,
scale_grad_by_freq=False, scale_grad_by_freq=False,
sparse=False): sparse=False):
return torch.empty(*input.shape, weight.shape[-1], device="meta") return torch.empty(*input.shape, weight.shape[-1], device="meta")
\ No newline at end of file
import torch import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.layer_norm) @meta_patched_function.register(torch.nn.functional.layer_norm)
...@@ -16,4 +17,4 @@ def torch_nn_func_batchnorm(input, ...@@ -16,4 +17,4 @@ def torch_nn_func_batchnorm(input,
training=False, training=False,
momentum=0.1, momentum=0.1,
eps=1e-05): eps=1e-05):
return torch.empty(input.shape, device='meta') return torch.empty(input.shape, device='meta')
\ No newline at end of file
import operator import operator
import torch import torch
from ..registry import meta_patched_function
from colossalai.fx.proxy import ColoProxy from colossalai.fx.proxy import ColoProxy
from ...registry import meta_patched_function
@meta_patched_function.register(operator.getitem) @meta_patched_function.register(operator.getitem)
def operator_getitem(a, b): def operator_getitem(a, b):
......
import torch import torch
from ..registry import meta_patched_function
from ...registry import meta_patched_function
@meta_patched_function.register(torch.arange) @meta_patched_function.register(torch.arange)
......
import torch import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.ReLU) @meta_patched_module.register(torch.nn.ReLU)
......
import math import math
import torch import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Conv1d) @meta_patched_module.register(torch.nn.Conv1d)
......
import torch import torch
from ..registry import meta_patched_module
from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Embedding) @meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input): def torch_nn_embedding(self, input):
result_shape = input.shape + (self.embedding_dim,) result_shape = input.shape + (self.embedding_dim,)
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device='meta')
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment