Unverified commit fb873227, authored by YuliangLiu0306 and committed by GitHub
Browse files

[autoparallel] fix spelling error (#2270)

parent af32022f
...@@ -6,9 +6,9 @@ from typing import List ...@@ -6,9 +6,9 @@ from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from .operator_handler import OperatorHandler from .operator_handler import OperatorHandler
...@@ -82,13 +82,13 @@ class MatVecStrategyGenerator(StrategyGenerator): ...@@ -82,13 +82,13 @@ class MatVecStrategyGenerator(StrategyGenerator):
class MatMulStrategyGenerator(StrategyGenerator): class MatMulStrategyGenerator(StrategyGenerator):
""" """
MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm. a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
A matmul can be formulated as [n, p] x [p, q] = [n, q] A matmul can be formulated as [n, p] x [p, q] = [n, q]
Args: Args:
is_linear (bool): whether this generator is used for nn.Linear and F.linear. is_linear (bool): whether this generator is used for nn.Linear and F.linear.
This will incur extra transformation of the dim partitioning as the weight is transposed. This will incur extra transformation of the dim partitioning as the weight is transposed.
""" """
...@@ -255,7 +255,7 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator): ...@@ -255,7 +255,7 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
""" """
Generate sharding strategies for the batched matrix multiplication. Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j] [b, i, k] x [b, k, j] -> [b, i, j]
""" """
...@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler): ...@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...@@ -451,7 +451,7 @@ class DotHandler(OperatorHandler): ...@@ -451,7 +451,7 @@ class DotHandler(OperatorHandler):
# create and register strategy # create and register strategy
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=toatl_memory_cost, memory_cost=toatl_memory_cost,
...@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler): ...@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]} dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...@@ -491,7 +491,7 @@ class DotHandler(OperatorHandler): ...@@ -491,7 +491,7 @@ class DotHandler(OperatorHandler):
communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0) communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
communication_cost = communication_cost_activation_forward + communication_cost_grad_backward communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=toatl_memory_cost, memory_cost=toatl_memory_cost,
...@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler): ...@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]} dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...@@ -529,7 +529,7 @@ class DotHandler(OperatorHandler): ...@@ -529,7 +529,7 @@ class DotHandler(OperatorHandler):
communication_cost = communication_cost_activation_backward + communication_cost_activation_forward communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=toatl_memory_cost, memory_cost=toatl_memory_cost,
...@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler): ...@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {} dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...@@ -564,7 +564,7 @@ class DotHandler(OperatorHandler): ...@@ -564,7 +564,7 @@ class DotHandler(OperatorHandler):
# compute the communication cost of this strategy # compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim) communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=toatl_memory_cost, memory_cost=toatl_memory_cost,
...@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler): ...@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim]} dim_partition_dict_for_output = {1: [mesh_dim]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...@@ -600,7 +600,7 @@ class DotHandler(OperatorHandler): ...@@ -600,7 +600,7 @@ class DotHandler(OperatorHandler):
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim) communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
communication_cost = communication_cost_activation_backward communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=toatl_memory_cost, memory_cost=toatl_memory_cost,
...@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler): ...@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]} dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...@@ -636,7 +636,7 @@ class DotHandler(OperatorHandler): ...@@ -636,7 +636,7 @@ class DotHandler(OperatorHandler):
communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0) communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
communication_cost = communication_cost_weight_backward communication_cost = communication_cost_weight_backward
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=toatl_memory_cost, memory_cost=toatl_memory_cost,
...@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler): ...@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {} dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...@@ -673,7 +673,7 @@ class DotHandler(OperatorHandler): ...@@ -673,7 +673,7 @@ class DotHandler(OperatorHandler):
activation_memory_cost, 0) activation_memory_cost, 0)
communication_cost = communication_cost_forward_activation communication_cost = communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=toatl_memory_cost, memory_cost=toatl_memory_cost,
...@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler): ...@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]} dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy # generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight]) resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...@@ -709,7 +709,7 @@ class DotHandler(OperatorHandler): ...@@ -709,7 +709,7 @@ class DotHandler(OperatorHandler):
input_grad_memory_cost, 0) input_grad_memory_cost, 0)
communication_cost = communication_cost_activation_backward communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=toatl_memory_cost, memory_cost=toatl_memory_cost,
......
...@@ -5,14 +5,14 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler ...@@ -5,14 +5,14 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
from .experimental import PermuteHandler, ViewHandler from .experimental import PermuteHandler, ViewHandler
from .getatrr_handler import GetattrHandler from .getattr_handler import GetattrHandler
from .getitem_handler import GetItemHandler from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OuputHandler from .output_handler import OutputHandler
from .placeholder_handler import PlacehodlerHandler from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry from .registry import operator_registry
from .reshape_handler import ReshapeHandler from .reshape_handler import ReshapeHandler
from .softmax_handler import SoftmaxHandler from .softmax_handler import SoftmaxHandler
...@@ -24,7 +24,7 @@ from .where_handler import WhereHandler ...@@ -24,7 +24,7 @@ from .where_handler import WhereHandler
__all__ = [ __all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler' 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
......
...@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect ...@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator from .strategy import OutputGenerator, StrategyGenerator
__all__ = ['OuputHandler'] __all__ = ['OutputHandler']
class OuputHandler(NodeHandler): class OutputHandler(NodeHandler):
""" """
A OuputHandler which deals with the sharding strategies for Output Node. A OutputHandler which deals with the sharding strategies for Output Node.
""" """
def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
......
...@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect ...@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator from .strategy import PlaceholderGenerator, StrategyGenerator
__all__ = ['PlacehodlerHandler'] __all__ = ['PlaceholderHandler']
class PlacehodlerHandler(NodeHandler): class PlaceholderHandler(NodeHandler):
""" """
A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node. A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
""" """
def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
......
...@@ -9,8 +9,8 @@ from torch.fx import Graph, Node ...@@ -9,8 +9,8 @@ from torch.fx import Graph, Node
from colossalai.auto_parallel.tensor_shard.node_handler import ( from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler, GetattrHandler,
OuputHandler, OutputHandler,
PlacehodlerHandler, PlaceholderHandler,
operator_registry, operator_registry,
) )
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
...@@ -93,7 +93,7 @@ class StrategiesConstructor: ...@@ -93,7 +93,7 @@ class StrategiesConstructor:
else: else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
placeholder_option = 'replicated' placeholder_option = 'replicated'
placeholder_handler = PlacehodlerHandler(node, placeholder_handler = PlaceholderHandler(node,
self.device_mesh, self.device_mesh,
strategies_vector, strategies_vector,
placeholder_option=placeholder_option) placeholder_option=placeholder_option)
...@@ -140,7 +140,7 @@ class StrategiesConstructor: ...@@ -140,7 +140,7 @@ class StrategiesConstructor:
else: else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
output_option = 'replicated' output_option = 'replicated'
output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy() output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector) self.remove_duplicated_strategy(strategies_vector)
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx import ColoGraphModule, ColoTracer
......
...@@ -7,7 +7,7 @@ import torch.nn as nn ...@@ -7,7 +7,7 @@ import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
...@@ -145,7 +145,7 @@ def test_getitem_from_tuple_handler(): ...@@ -145,7 +145,7 @@ def test_getitem_from_tuple_handler():
split_strategies_vector = StrategiesVector(split_node) split_strategies_vector = StrategiesVector(split_node)
# build handler # build handler
input_handler = PlacehodlerHandler( input_handler = PlaceholderHandler(
node=input_node, node=input_node,
device_mesh=device_mesh, device_mesh=device_mesh,
strategies_vector=input_strategies_vector, strategies_vector=input_strategies_vector,
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx import ColoGraphModule, ColoTracer
...@@ -39,10 +39,10 @@ def test_output_handler(output_option): ...@@ -39,10 +39,10 @@ def test_output_handler(output_option):
output_strategies_vector = StrategiesVector(output_node) output_strategies_vector = StrategiesVector(output_node)
# build handler # build handler
otuput_handler = OuputHandler(node=output_node, otuput_handler = OutputHandler(node=output_node,
device_mesh=device_mesh, device_mesh=device_mesh,
strategies_vector=output_strategies_vector, strategies_vector=output_strategies_vector,
output_option=output_option) output_option=output_option)
otuput_handler.register_strategy(compute_resharding_cost=False) otuput_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping # check operation data mapping
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx import ColoGraphModule, ColoTracer
...@@ -36,7 +36,7 @@ def test_placeholder_handler(placeholder_option): ...@@ -36,7 +36,7 @@ def test_placeholder_handler(placeholder_option):
placeholder_node = list(graph.nodes)[0] placeholder_node = list(graph.nodes)[0]
placeholder_strategies_vector = StrategiesVector(placeholder_node) placeholder_strategies_vector = StrategiesVector(placeholder_node)
# build handler # build handler
placeholder_handler = PlacehodlerHandler(node=placeholder_node, placeholder_handler = PlaceholderHandler(node=placeholder_node,
device_mesh=device_mesh, device_mesh=device_mesh,
strategies_vector=placeholder_strategies_vector, strategies_vector=placeholder_strategies_vector,
placeholder_option=placeholder_option) placeholder_option=placeholder_option)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment