Unverified Commit 1baeb39c authored by digger yu's avatar digger yu Committed by GitHub
Browse files

[NFC] fix typo with colossalai/auto_parallel/tensor_shard (#3742)

* fix typo applications/ and colossalai/ date 5.11

* fix typo colossalai/
parent 7386c666
...@@ -75,7 +75,7 @@ class NodeHandler(ABC): ...@@ -75,7 +75,7 @@ class NodeHandler(ABC):
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
] ]
# create data structrure to store costs # create data structure to store costs
if node not in resharding_costs: if node not in resharding_costs:
resharding_costs[node] = [] resharding_costs[node] = []
......
...@@ -24,7 +24,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -24,7 +24,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
To keep the math consistency, there are two ways to do BatchNorm if the input To keep the math consistency, there are two ways to do BatchNorm if the input
shards on batch dimension: shards on batch dimension:
1. We gather the input partitions through batch dimension, then do the normal BatchNorm. 1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
2. We do the SyncBatchNorm on each input partition seperately, the SyncBN op will help 2. We do the SyncBatchNorm on each input partition separately, the SyncBN op will help
us to keep the computing correctness. us to keep the computing correctness.
In this generator, both methods will be considered. In this generator, both methods will be considered.
""" """
...@@ -212,7 +212,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -212,7 +212,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action # set communication action
# For SyncBN case, we don't need to do communication for weight and bias. # For SyncBN case, we don't need to do communication for weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node. # to SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action( output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"], sharding_spec=sharding_spec_mapping["output"],
...@@ -250,7 +250,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -250,7 +250,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action # set communication action
# For SyncBN case, we don't need to do communication for gradients of weight and bias. # For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node. # to SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action( output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"], sharding_spec=sharding_spec_mapping["output"],
...@@ -298,7 +298,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -298,7 +298,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action # set communication action
# For SyncBN case, we don't need to do communication for gradients of weight and bias. # For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # TODO: the communication happens internally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node. # to SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action( output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"], sharding_spec=sharding_spec_mapping["output"],
......
...@@ -51,7 +51,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator): ...@@ -51,7 +51,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
# compute fwd memory cost in bytes # compute fwd memory cost in bytes
# as the elementwise ops are not memory-intensive # as the elementwise ops are not memory-intensive
# we approximate the fwd memroy cost to be the output # we approximate the fwd memory cost to be the output
# and the backward memory cost to be grad of input and other # and the backward memory cost to be grad of input and other
input_bytes = self._compute_size_in_bytes(strategy, 'input') input_bytes = self._compute_size_in_bytes(strategy, 'input')
other_bytes = self._compute_size_in_bytes(strategy, 'other') other_bytes = self._compute_size_in_bytes(strategy, 'other')
......
...@@ -225,7 +225,7 @@ class StrategyGenerator(ABC): ...@@ -225,7 +225,7 @@ class StrategyGenerator(ABC):
if isinstance(meta_data, torch.Tensor): if isinstance(meta_data, torch.Tensor):
element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data) element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
else: else:
# if meta_data is not a tensor, we count the memroy as 0 # if meta_data is not a tensor, we count the memory as 0
element_bytes = 0 element_bytes = 0
total_bytes += element_bytes total_bytes += element_bytes
...@@ -233,7 +233,7 @@ class StrategyGenerator(ABC): ...@@ -233,7 +233,7 @@ class StrategyGenerator(ABC):
if isinstance(op_data.data, torch.Tensor): if isinstance(op_data.data, torch.Tensor):
total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data) total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
else: else:
# if op_data.data is not a tensor, we count the memroy as 0 # if op_data.data is not a tensor, we count the memory as 0
total_bytes = 0 total_bytes = 0
return total_bytes return total_bytes
......
...@@ -9,7 +9,7 @@ class CostGraph: ...@@ -9,7 +9,7 @@ class CostGraph:
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list. CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as 2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of target node and following nodes. be given by the StrategiesVector depending on the type of target node and following nodes.
Argument: Argument:
...@@ -90,7 +90,7 @@ class CostGraph: ...@@ -90,7 +90,7 @@ class CostGraph:
if self.simplify and strategies_vector.check_merge(): if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes: for followed_node in strategies_vector.predecessor_nodes:
# we only merge node pairs which src node has a tensor element inside. # we only merge node pairs which src node has a tensor element inside.
# This is necessay because the node without a tensor element inside will not # This is necessary because the node without a tensor element inside will not
# be assigned any strategy. # be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data): if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node)) self.merge_pair.append((followed_node, dst_node))
......
...@@ -83,7 +83,7 @@ class GraphAnalyser: ...@@ -83,7 +83,7 @@ class GraphAnalyser:
def liveness_analysis(self) -> List[LiveStage]: def liveness_analysis(self) -> List[LiveStage]:
""" """
Analyse the graph to obtain the variable liveness information. This function returns Analyses the graph to obtain the variable liveness information. This function returns
an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
""" """
compute_nodes = self.graph.nodes compute_nodes = self.graph.nodes
...@@ -91,7 +91,7 @@ class GraphAnalyser: ...@@ -91,7 +91,7 @@ class GraphAnalyser:
# checked: record all variables created since the first stage # checked: record all variables created since the first stage
# all: record the live variables only exist until the current stage. # all: record the live variables only exist until the current stage.
# this can be different from the `checked list` as some varialbes may be destroyed prior to this stage. # this can be different from the `checked list` as some variables may be destroyed prior to this stage.
# unique: record the unique live variables only exist until the current stage. # unique: record the unique live variables only exist until the current stage.
# this is different from `all list` as some variables are duplicated. # this is different from `all list` as some variables are duplicated.
checked_variables = LiveVariableVector() checked_variables = LiveVariableVector()
...@@ -103,7 +103,7 @@ class GraphAnalyser: ...@@ -103,7 +103,7 @@ class GraphAnalyser:
# find new living variables # # find new living variables #
############################# #############################
# detect whether the current op is an in-place op # detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplciate var # if it is an in-place op, we would deem it as a duplicate var
is_inplace = False is_inplace = False
if node.op == 'call_function': if node.op == 'call_function':
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
......
...@@ -44,7 +44,7 @@ class Solver: ...@@ -44,7 +44,7 @@ class Solver:
graph: The computing graph to be optimized. graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph. strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph. cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints. graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution. memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, solver will use a series of solutions based on different memory budgets. solution_numbers: If solution_numbers is larger than one, solver will use a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
......
...@@ -33,7 +33,7 @@ def run_on_environment_flag(name: str): ...@@ -33,7 +33,7 @@ def run_on_environment_flag(name: str):
assert isinstance(name, str) assert isinstance(name, str)
flag = os.environ.get(name.upper(), '0') flag = os.environ.get(name.upper(), '0')
reason = f'Environment varialbe {name} is {flag}' reason = f'Environment variable {name} is {flag}'
if flag == '1': if flag == '1':
return pytest.mark.skipif(False, reason=reason) return pytest.mark.skipif(False, reason=reason)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment