"vscode:/vscode.git/clone" did not exist on "8accecd55bf1a5aaaeb4b84c06fac0d63850fd5e"
Unverified Commit 47b11c43 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel]add bcast matmul strategies (#1605)

parent edb67cb3
......@@ -14,7 +14,7 @@ ELEMENTWISE_FUNC_OP = [
RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv
operator.mul, operator.floordiv, operator.truediv, torch.matmul
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
......
......@@ -11,7 +11,6 @@ from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List
__all__ = ['DotHandler']
......@@ -465,7 +464,7 @@ class DotHandler(OperatorHandler):
# since weight of the linear layer is transposed
# the actual dim to be sharded is 1
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
......
......@@ -50,6 +50,15 @@ class StrategiesConstructor:
for strategy in remove_list:
strategies_vector.remove(strategy)
def _is_bcast_matmul(self, node):
is_bcast_matmul = False
if node.target is torch.matmul and len(node.args) == 2:
lhs_data = node.args[0]._meta_data
rhs_data = node.args[1]._meta_data
if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
is_bcast_matmul = True
return is_bcast_matmul
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
......@@ -222,7 +231,7 @@ class StrategiesConstructor:
conv_handler.register_strategy()
# linear function
elif target in LINEAR_FUNC_OP:
elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
# use DotHandler to create sharding strategies for linear node
# TODO: the operator_handler does NOT support function node processing now.
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
......
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
class MatmulModel(nn.Module):
    """Toy module whose forward pass is a single ``torch.matmul`` call.

    Used as a tracing target so the traced graph contains exactly one
    ``call_function`` node targeting ``torch.matmul``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        # No parameters of its own — just delegate to the functional matmul.
        return torch.matmul(x1, x2)
def test_conv_handler():
    """Check the strategy count produced for a broadcast matmul node.

    NOTE(review): the name is a copy-paste leftover from the conv handler
    test — this actually exercises the broadcast-matmul path. Kept as-is
    to avoid changing the test's public name.
    """
    # 2x2 device mesh over physical ranks
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

    # Trace MatmulModel; the resulting graph is:
    #   %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
    #   %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
    #   %matmul : [#users=1] = call_function[target=torch.matmul](args = (%x1, %x2), kwargs = {})
    #   return matmul
    model = MatmulModel()
    meta_args = {'x1': torch.rand(4, 4, 8).to('meta'), 'x2': torch.rand(4, 1, 8, 4).to('meta')}
    graph = ColoTracer().trace(root=model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)

    strategies_constructor = StrategiesConstructor(graph, device_mesh, SolverOptions(fast=True))
    strategies_constructor.build_strategies_and_cost()

    # Node order is [x1, x2, matmul, output]; index 2 is the matmul node.
    matmul_node = list(gm.graph.nodes)[2]
    matmul_strategies = strategies_constructor.strategy_map[matmul_node]
    assert len(matmul_strategies) == 30


if __name__ == '__main__':
    test_conv_handler()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment