Unverified Commit 32f81f14 authored by digger yu's avatar digger yu Committed by GitHub
Browse files

[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)

parent 21e29e22
...@@ -240,7 +240,7 @@ class GradScaler(object): ...@@ -240,7 +240,7 @@ class GradScaler(object):
for grads in per_dtype_grads.values(): for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
per_device_inv_scale.get(device)) per_device_inv_scale.get(device))
# For tensor parallel paramters it should be all-reduced over tensor parallel process group # For tensor parallel parameters it should be all-reduced over tensor parallel process group
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
vals = [val for val in per_device_found_inf._per_device_tensors.values()] vals = [val for val in per_device_found_inf._per_device_tensors.values()]
coalesced = _flatten_dense_tensors(vals) coalesced = _flatten_dense_tensors(vals)
......
...@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ...@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else: else:
_is_batch_dims_same = False _is_batch_dims_same = False
# retireve dimensions # retrieve dimensions
input_dim_00 = input_tensors[0].shape[-2] input_dim_00 = input_tensors[0].shape[-2]
input_dim_01 = input_tensors[0].shape[-1] input_dim_01 = input_tensors[0].shape[-1]
input_dim_10 = input_tensors[1].shape[-2] input_dim_10 = input_tensors[1].shape[-2]
......
...@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): ...@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
return gm return gm
def _act_annotataion_pass(gm: torch.fx.GraphModule): def _act_annotation_pass(gm: torch.fx.GraphModule):
""" """
This pass is used to add the act annotation to the new inserted nodes. This pass is used to add the act annotation to the new inserted nodes.
""" """
......
...@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size], ...@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
return size return size
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
strategies_constructor: StrategiesConstructor): strategies_constructor: StrategiesConstructor):
""" """
This method is used to stick the solution strategy to the nodes and add the information This method is used to stick the solution strategy to the nodes and add the information
...@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule, ...@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
device_mesh: DeviceMesh, device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor, strategies_constructor: StrategiesConstructor,
overlap=False): overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass( gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
gm, solution, strategies_constructor) gm, solution, strategies_constructor)
gm = size_value_converting_pass(gm, device_mesh) gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh) gm = node_args_converting_pass(gm, device_mesh)
......
...@@ -64,7 +64,7 @@ class TraceFlow(object): ...@@ -64,7 +64,7 @@ class TraceFlow(object):
return False return False
return True return True
def _assgin_single_node_flow( def _assign_single_node_flow(
self, self,
arg_node: Node, arg_node: Node,
start_idx: int, start_idx: int,
...@@ -177,7 +177,7 @@ class TraceFlow(object): ...@@ -177,7 +177,7 @@ class TraceFlow(object):
if get_node_shape(arg) is None: if get_node_shape(arg) is None:
continue continue
arg_list.append(arg) arg_list.append(arg)
flow_flag = self._assgin_single_node_flow( flow_flag = self._assign_single_node_flow(
arg, arg,
start_idx, start_idx,
end_idx, end_idx,
...@@ -315,7 +315,7 @@ class TraceFlow(object): ...@@ -315,7 +315,7 @@ class TraceFlow(object):
chunk_info["args"]["prepose_nodes"] = prepose_nodes chunk_info["args"]["prepose_nodes"] = prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleteing them in the loop # we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1) chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# also need to get some prepose node's arg out of non_chunk_inputs # also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]: for n in chunk_info["args"]["prepose_nodes"]:
......
...@@ -461,7 +461,7 @@ class TraceIndice(object): ...@@ -461,7 +461,7 @@ class TraceIndice(object):
nodes_in.append(node_in) nodes_in.append(node_in)
self._inherit_more_indice_from_node_with_exclude(node_in, node) self._inherit_more_indice_from_node_with_exclude(node_in, node)
def _assgin_no_change_indice(self, node, idx): def _assign_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx) self._assign_indice_as_input(node, idx)
for node_in in node.args: for node_in in node.args:
if type(node_in) == type(node): if type(node_in) == type(node):
...@@ -792,7 +792,7 @@ class TraceIndice(object): ...@@ -792,7 +792,7 @@ class TraceIndice(object):
self._add_dim(node_idx, i) self._add_dim(node_idx, i)
dim_from.reverse() dim_from.reverse()
# inheirt indice from current node # inherit indice from current node
if len(dim_from) != 0 and len(dim_to) != 0: if len(dim_from) != 0 and len(dim_to) != 0:
if dim_diff == 1: if dim_diff == 1:
if origin_shape[dim_from[0]] == 1: if origin_shape[dim_from[0]] == 1:
...@@ -852,7 +852,7 @@ class TraceIndice(object): ...@@ -852,7 +852,7 @@ class TraceIndice(object):
elif "split" == node_name: elif "split" == node_name:
self._assign_split_indice(node, idx) self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]): elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
self._assgin_no_change_indice(node, idx) self._assign_no_change_indice(node, idx)
elif "new_ones" == node_name: elif "new_ones" == node_name:
self._assign_all_indice(node, idx) self._assign_all_indice(node, idx)
elif "flatten" == node_name: elif "flatten" == node_name:
...@@ -914,7 +914,7 @@ class TraceIndice(object): ...@@ -914,7 +914,7 @@ class TraceIndice(object):
elif "conv2d" == node_name: elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx) self._assign_conv2d_indice(node, idx)
elif "identity" == node_name: elif "identity" == node_name:
self._assgin_no_change_indice(node, idx) self._assign_no_change_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]): elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
self._assign_elementwise_indice(node, idx) self._assign_elementwise_indice(node, idx)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment