Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
......@@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
input_node.name)
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
shape_str,
input_node.name,
input_node.name,
)
tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
input_node.name, input_node.name)
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
chunk_output[i].name,
shape_str,
input_node.name,
input_node.name,
)
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_output_dim[0]]
......@@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
return context
def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
def _gen_loop_end(
chunk_inputs: List[Node],
chunk_non_compute_inputs: List[Node],
node_list: List[Node],
chunk_outputs_idx: int,
chunk_outputs_non_tensor: List[Node],
search_chunk: SearchChunk,
) -> str:
"""
Generate chunk loop end
......@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
if (
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body
......@@ -203,11 +229,12 @@ def _add_node_slice(
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
get_node_shape(chunk_node))
chunk_slice = _gen_chunk_slice_dim(
chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
)
if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])):
for i in range(len(chunk_node.meta["tensor_meta"])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
......@@ -216,13 +243,15 @@ def _add_node_slice(
return body
def emit_code_with_chunk(body: List[str],
nodes: Iterable[Node],
emit_node_func: Callable,
delete_unused_value_func: Callable,
search_chunk: SearchChunk,
chunk_infos: List,
eval_mem: bool = False):
def emit_code_with_chunk(
body: List[str],
nodes: Iterable[Node],
emit_node_func: Callable,
delete_unused_value_func: Callable,
search_chunk: SearchChunk,
chunk_infos: List,
eval_mem: bool = False,
):
"""
Emit code with chunk according to chunk_infos.
......@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str],
chunk_ends = [i["region"][1] for i in chunk_infos]
# chunk inputs
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs
......@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"],
))
)
)
if within_chunk_region:
emit_node_func(node, body)
......@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
% (node.name)
)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
......@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
% (node.name)
)
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
_gen_loop_end(
chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
node_list,
chunk_ends[region_idx],
chunk_outputs_non_tensor[region_idx],
search_chunk,
)
)
within_chunk_region = False
node_idx += 1
......@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
eval_mem: bool = False) -> None:
def __init__(
self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
eval_mem: bool = False,
) -> None:
super().__init__()
self.eval_mem = eval_mem
# find the chunk regions
......@@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE:
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
......@@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE:
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
......@@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE:
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
......@@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE:
assert isinstance(node.target, str)
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})")
f"({_format_args(node.args[1:], node.kwargs)})"
)
return
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
body.append(
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
)
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
and node.args[1].isidentifier() and len(node.args) == 2):
if (
global_name == "getattr"
and isinstance(node.args, tuple)
and isinstance(node.args[1], str)
and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return
body.append(
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
)
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == "call_module":
assert isinstance(node.target, str)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return
elif node.op == "get_attr":
assert isinstance(node.target, str)
......@@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
self.eval_mem)
emit_code_with_chunk(
body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
......
import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import NodeMgr, get_node_shape, is_non_memory_node
......@@ -62,12 +59,9 @@ class EstimateMemory(object):
delete_node_dict[node] = max(node_user_idx)
return delete_node_dict
def _remove_deactive_node(self,
user_idx: int,
user: Node,
active_nodes: List,
delete_node_dict: List,
kept_nodes: List = None) -> None:
def _remove_deactive_node(
self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None
) -> None:
"""
remove deactivate nodes from active nodes
"""
......@@ -169,7 +163,7 @@ class EstimateMemory(object):
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem
chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_all = []
if use_chunk:
......@@ -184,7 +178,6 @@ class EstimateMemory(object):
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_mgr.get_node_list()):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
......@@ -193,8 +186,9 @@ class EstimateMemory(object):
# determine chunk ratio for current node
if chunk_within:
chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
chunk_sizes[chunk_region_idx])
chunk_ratio = self._get_chunk_ratio(
node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
)
# add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio)
......@@ -222,7 +216,7 @@ class EstimateMemory(object):
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None
......
......@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
......@@ -121,8 +121,10 @@ class SearchChunk(object):
# check if peak node already in chunk info
if chunk_regions is not None:
for i in chunk_regions:
if i["region"][0] < peak_region[0] <= i["region"][1] or \
i["region"][0] < peak_region[1] <= i["region"][1]:
if (
i["region"][0] < peak_region[0] <= i["region"][1]
or i["region"][0] < peak_region[1] <= i["region"][1]
):
return None
active_node_num = [len(i) for i in active_node]
......@@ -146,9 +148,9 @@ class SearchChunk(object):
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
......@@ -171,7 +173,7 @@ class SearchChunk(object):
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed
if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx)
......@@ -180,8 +182,9 @@ class SearchChunk(object):
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
end_idx):
if not self.trace_flow.check_region_start_end(
start_node, start_dim, start_idx, end_node, end_dim, end_idx
):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
......@@ -203,7 +206,7 @@ class SearchChunk(object):
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes
input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {}
for arg in n.args:
......@@ -215,7 +218,8 @@ class SearchChunk(object):
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.node_mgr.get_node_by_idx(end_idx)):
self.node_mgr.get_node_by_idx(end_idx)
):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
......@@ -279,15 +283,18 @@ class SearchChunk(object):
chunk_infos.append(chunk_info)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos)
self.node_mgr.get_node_list(), chunk_infos
)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
get_logger().info(
"AutoChunk find chunk region %d = (%d, %d)"
% (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])
)
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
chunk_infos,
print_mem=True)
self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos, print_mem=True
)
return chunk_infos
......@@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object):
def __init__(
self,
trace_indice: TraceIndice,
......@@ -20,7 +19,7 @@ class SelectChunk(object):
self.node_mgr = node_mgr
if max_memory is not None:
self.stratge = "fit_memory"
self.max_memory = max_memory # MB
self.max_memory = max_memory # MB
else:
self.stratge = "min_memory"
......@@ -57,16 +56,18 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
......@@ -90,13 +91,15 @@ class SelectChunk(object):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
chunk_infos)
chunk_info["chunk_size"] = self._chunk_size_binary_search(
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
)
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
......@@ -109,9 +112,10 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
......@@ -139,8 +143,10 @@ class SelectChunk(object):
return None
# get max possible chunk region
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
max_possible_chunk_region = (
min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]),
)
# get mem for chunk region
regions_dict_list = []
......@@ -149,15 +155,17 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
regions_dict_list.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
regions_dict_list.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
......@@ -175,7 +183,9 @@ class SelectChunk(object):
return False
for i in chunk_infos:
region = i["region"]
if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
(chunk_region_start < region[0] and chunk_region_end < region[0])):
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
return False
return True
......@@ -16,7 +16,6 @@ from .utils import (
class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
......@@ -151,7 +150,7 @@ class TraceFlow(object):
return True
def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
......@@ -266,7 +265,7 @@ class TraceFlow(object):
maybe_prepose_nodes.sort(
key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True,
) # from last node to first node
) # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
......@@ -328,7 +327,8 @@ class TraceFlow(object):
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
)
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
......@@ -371,8 +371,9 @@ class TraceFlow(object):
return chunk_info
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
chunk_info: Dict):
def _get_other_output_info(
self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
......@@ -384,8 +385,8 @@ class TraceFlow(object):
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
......@@ -421,7 +422,8 @@ class TraceFlow(object):
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
)
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
......@@ -443,8 +445,11 @@ class TraceFlow(object):
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
reshape_args[0].meta['fwd_out']) > 1:
if (
len(reshape_args) == 1
and get_node_shape(reshape_args[0]) is None
and len(reshape_args[0].meta["fwd_out"]) > 1
):
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
......@@ -462,16 +467,17 @@ class TraceFlow(object):
chunk_info["reshape_size"] = reshape_size
return chunk_info
def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
def check_region_start_end(
self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
) -> bool:
"""
check if region start and end is legal
"""
# dim cannot be None
if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
return False
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
return False
# must have users
if len(end_node.users) == 0:
......
import copy
from typing import Dict, List, Tuple
from typing import Dict, List
from torch.fx.node import Node
......@@ -412,7 +412,7 @@ class TraceIndice(object):
node_idx (int)
"""
# get conv input
assert node.kwargs['size'] is None
assert node.kwargs["size"] is None
assert len(get_node_shape(node)) == 4
# assign index
......@@ -826,7 +826,7 @@ class TraceIndice(object):
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
......@@ -876,10 +876,24 @@ class TraceIndice(object):
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(n == node_name for n in [
"mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
"sin", "cos"
]):
elif any(
n == node_name
for n in [
"mul",
"add",
"sigmoid",
"relu",
"sub",
"truediv",
"pow",
"dropout",
"where",
"tanh",
"exp",
"sin",
"cos",
]
):
self._assign_elementwise_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
......@@ -920,7 +934,7 @@ class TraceIndice(object):
else:
raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr":
self._assign_all_indice(node, idx) # get param
self._assign_all_indice(node, idx) # get param
elif node.op == "output":
continue
else:
......
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from typing import Any, Dict, List, Union
from torch.fx.node import Node
......@@ -10,7 +10,6 @@ logger = get_dist_logger()
class NodeMgr(object):
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
......@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
if (input_node not in nodes and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)):
if (
input_node not in nodes
and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)
):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
if (output_node not in nodes and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)):
if (
output_node not in nodes
and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)
):
output_nodes.append(node)
return input_nodes, output_nodes
......@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
node.meta['fwd_out'][0], int):
elif (
len(node.meta["fwd_out"]) > 0
and isinstance(node.meta["fwd_out"], list)
and isinstance(node.meta["fwd_out"][0], int)
):
out.append(node)
return out
import torch
import torch.nn as nn
__all__ = ['Accelerator']
__all__ = ["Accelerator"]
_supported_devices = [
'cpu',
'cuda',
"cpu",
"cuda",
# To be supported
# 'xpu',
# 'npu',
......@@ -25,21 +24,22 @@ class Accelerator:
def __init__(self, device: str):
self.device = device
assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
assert (
self.device in _supported_devices
), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
def bind(self):
"""
Set the default device for the current process.
"""
if self.device == 'cpu':
if self.device == "cpu":
pass
elif self.device == 'cuda':
elif self.device == "cuda":
# TODO(FrankLeeeee): use global environment to check if it is a dist job
# if is_distributed:
# local_rank = EnvTable().get_local_rank()
# torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
torch.cuda.set_device(torch.device('cuda'))
pass
torch.cuda.set_device(torch.device("cuda"))
else:
raise ValueError(f"Device {self.device} is not supported yet")
......
......@@ -16,7 +16,7 @@ from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase
__all__ = ['Booster']
__all__ = ["Booster"]
class Booster:
......@@ -60,28 +60,31 @@ class Booster:
plugin (Plugin): The plugin to run the training. Default: None.
"""
def __init__(self,
device: Optional[str] = None,
mixed_precision: Optional[Union[MixedPrecision, str]] = None,
plugin: Optional[Plugin] = None) -> None:
def __init__(
self,
device: Optional[str] = None,
mixed_precision: Optional[Union[MixedPrecision, str]] = None,
plugin: Optional[Plugin] = None,
) -> None:
if plugin is not None:
assert isinstance(
plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
plugin, Plugin
), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
self.plugin = plugin
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
if device is not None:
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
else:
device = device or 'cuda'
device = device or "cuda"
self.accelerator = Accelerator(device)
# set precision
if self.plugin and self.plugin.control_precision():
if mixed_precision is not None:
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
......@@ -95,7 +98,7 @@ class Booster:
self.mixed_precision = mixed_precision
else:
raise ValueError(
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}."
)
if self.plugin is not None and self.plugin.control_checkpoint_io():
......@@ -131,7 +134,8 @@ class Booster:
# transform model for mixed precision
if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, criterion, dataloader, lr_scheduler
)
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
......@@ -154,13 +158,15 @@ class Booster:
# TODO(frank lee): implement this method with plugin
optimizer.backward(loss)
def execute_pipeline(self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[Optimizer] = None,
return_loss: bool = True,
return_outputs: bool = False) -> Dict[str, Any]:
def execute_pipeline(
self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[Optimizer] = None,
return_loss: bool = True,
return_outputs: bool = False,
) -> Dict[str, Any]:
"""
Execute forward & backward when utilizing pipeline parallel.
Return loss or Huggingface style model outputs if needed.
......@@ -185,8 +191,9 @@ class Booster:
ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None.
"""
assert isinstance(self.plugin,
PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
assert isinstance(
self.plugin, PipelinePluginBase
), f"The plugin {self.plugin.__class__.__name__} does not support pipeline."
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
......@@ -200,8 +207,10 @@ class Booster:
Returns:
contextmanager: Context to disable gradient synchronization.
"""
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
assert (
self.plugin is not None
), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync."
assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
return self.plugin.no_sync(model, optimizer)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
......@@ -217,14 +226,16 @@ class Booster:
"""
self.checkpoint_io.load_model(model, checkpoint, strict)
def save_model(self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False) -> None:
def save_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""Save model to checkpoint.
Args:
......@@ -239,13 +250,15 @@ class Booster:
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
"""
self.checkpoint_io.save_model(model,
checkpoint=checkpoint,
shard=shard,
gather_dtensor=gather_dtensor,
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors)
self.checkpoint_io.save_model(
model,
checkpoint=checkpoint,
shard=shard,
gather_dtensor=gather_dtensor,
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors,
)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""Load optimizer from checkpoint.
......@@ -260,13 +273,15 @@ class Booster:
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
def save_optimizer(self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024) -> None:
def save_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
) -> None:
"""
Save optimizer to checkpoint.
......
......@@ -6,16 +6,22 @@ from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision
__all__ = [
'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision',
'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision'
"MixedPrecision",
"mixed_precision_factory",
"FP16_Apex_MixedPrecision",
"FP16_Torch_MixedPrecision",
"FP32_MixedPrecision",
"BF16_MixedPrecision",
"FP8_MixedPrecision",
"FP16NaiveMixedPrecision",
]
_mixed_precision_mapping = {
'fp16': FP16TorchMixedPrecision,
'fp16_apex': FP16ApexMixedPrecision,
'fp16_naive': FP16NaiveMixedPrecision,
'bf16': BF16MixedPrecision,
'fp8': FP8MixedPrecision
"fp16": FP16TorchMixedPrecision,
"fp16_apex": FP16ApexMixedPrecision,
"fp16_naive": FP16NaiveMixedPrecision,
"bf16": BF16MixedPrecision,
"fp8": FP8MixedPrecision,
}
......@@ -31,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
return _mixed_precision_mapping[mixed_precision_type]()
else:
raise ValueError(
f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}"
)
......@@ -23,16 +23,18 @@ class FP16ApexMixedPrecision(MixedPrecision):
max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
"""
def __init__(self,
opt_level: Optional[str] = "O1",
cast_model_type: torch.dtype = None,
patch_torch_functions: bool = None,
keep_batchnorm_fp32: Union[bool, str] = None,
master_weights: bool = None,
loss_scale: Union[float, str] = None,
cast_model_outputs: Any = None,
num_losses: Optional[int] = 1,
verbosity: int = 1,
min_loss_scale: float = None,
max_loss_scale: float = 2.**24) -> None:
def __init__(
self,
opt_level: Optional[str] = "O1",
cast_model_type: torch.dtype = None,
patch_torch_functions: bool = None,
keep_batchnorm_fp32: Union[bool, str] = None,
master_weights: bool = None,
loss_scale: Union[float, str] = None,
cast_model_outputs: Any = None,
num_losses: Optional[int] = 1,
verbosity: int = 1,
min_loss_scale: float = None,
max_loss_scale: float = 2.0**24,
) -> None:
pass
......@@ -15,12 +15,14 @@ class FP16NaiveMixedPrecision(MixedPrecision):
verbose(bool): if set to `True`, will print debug info.
"""
def __init__(self,
log_num_zeros_in_grad: bool,
initial_scale: int,
growth_factor: int,
backoff_factor: float,
hysteresis: int,
max_scale: int,
verbose: bool = None) -> None:
def __init__(
self,
log_num_zeros_in_grad: bool,
initial_scale: int,
growth_factor: int,
backoff_factor: float,
hysteresis: int,
max_scale: int,
verbose: bool = None,
) -> None:
pass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .mixed_precision_base import MixedPrecision
__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
class TorchAMPOptimizer(OptimizerWrapper):
......@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper):
calls that may cause the scale to increase. Default: 2000.
"""
def __init__(self,
optim: Optimizer,
init_scale: float = 2.**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000) -> None:
def __init__(
self,
optim: Optimizer,
init_scale: float = 2.0**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
) -> None:
super().__init__(optim)
self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval)
self.scaler = torch.cuda.amp.GradScaler(
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
)
def backward(self, loss: Tensor, *args, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
......@@ -60,12 +64,14 @@ class TorchAMPOptimizer(OptimizerWrapper):
self.unscale_grad()
super().clip_grad_by_value(clip_value, *args, **kwargs)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> None:
def clip_grad_by_norm(
self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = False,
*args,
**kwargs,
) -> None:
self.unscale_grad()
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
......@@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision):
calls that may cause the scale to increase. Default: 2000.
"""
def __init__(self,
init_scale: float = 2.**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000) -> None:
def __init__(
self,
init_scale: float = 2.0**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
) -> None:
super().__init__()
self.torch_amp_kwargs = dict(init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval)
def configure(self,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
self.torch_amp_kwargs = dict(
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
)
def configure(
self,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model)
if optimizer is not None:
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
......
......@@ -4,11 +4,12 @@ from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']
__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
import torch
from packaging import version
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
from .torch_fsdp_plugin import TorchFSDPPlugin
__all__.append('TorchFSDPPlugin')
__all__.append("TorchFSDPPlugin")
......@@ -10,25 +10,19 @@ from .plugin_base import Plugin
class DPPluginBase(Plugin):
"""This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.
"""
"""This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation."""
def __init__(self) -> None:
super().__init__()
assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
assert (
dist.is_initialized()
), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment"
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
def prepare_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
......@@ -60,11 +54,13 @@ class DPPluginBase(Plugin):
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
......@@ -27,14 +27,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats
from .dp_plugin_base import DPPluginBase
__all__ = ['GeminiPlugin']
__all__ = ["GeminiPlugin"]
SUPPORTED_PRECISION = ['fp16', 'bf16']
PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
......@@ -74,13 +73,15 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
"""
super().load_unsharded_optimizer(optimizer, checkpoint)
def save_sharded_model(self,
model: GeminiDDP,
checkpoint_path: str,
gather_dtensor: bool = False,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
def save_sharded_model(
self,
model: GeminiDDP,
checkpoint_path: str,
gather_dtensor: bool = False,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
"""
Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
......@@ -97,34 +98,37 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.module, checkpoint_path)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
def load_sharded_model(self,
model: GeminiDDP,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False):
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
):
"""
Load shard model, load model from multiple files.
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
size_per_shard: int):
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
......@@ -153,20 +157,24 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=is_master,
use_safetensors=False)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=is_master,
use_safetensors=False,
)
# Wrap up index file. Only save it on master rank.
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
"""
......@@ -185,8 +193,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Load param_groups.
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory.')
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_param_groups = torch.load(param_group_path)
optimizer.load_param_groups(saved_param_groups)
......@@ -274,11 +284,11 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
......@@ -300,7 +310,7 @@ class GeminiPlugin(DPPluginBase):
verbose: bool = False,
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
......@@ -319,16 +329,20 @@ class GeminiPlugin(DPPluginBase):
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
)
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
)
self.optim_kwargs = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
)
self.verbose = verbose
def support_no_sync(self) -> bool:
......@@ -344,7 +358,7 @@ class GeminiPlugin(DPPluginBase):
return True
def supported_devices(self) -> List[str]:
return ['cuda']
return ["cuda"]
def configure(
self,
......@@ -354,7 +368,6 @@ class GeminiPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
......@@ -368,13 +381,10 @@ class GeminiPlugin(DPPluginBase):
# wrap the model with Gemini
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(optimizer,
model.unwrap(),
**self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
)
return model, optimizer, criterion, dataloader, lr_scheduler
......
......@@ -37,10 +37,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class HybridParallelModule(ModelWrapper):
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
ddp_config: dict, custom_policy: Policy) -> None:
def __init__(
self,
module: Module,
precision: str,
shard_config: ShardConfig,
dp_group: ProcessGroup,
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
......@@ -54,13 +60,14 @@ class HybridParallelModule(ModelWrapper):
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
)
# setting mixed_precision
self.mixed_precision = None
if precision == 'fp16':
if precision == "fp16":
self.mixed_precision = torch.float16
elif precision == 'bf16':
elif precision == "bf16":
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
......@@ -123,22 +130,21 @@ def get_param_info(optim: Optimizer):
if optim is None:
return {}
param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
packed_group["params"] = []
packed_group = {k: v for k, v in group.items() if k != 'params'}
packed_group['params'] = []
for param_id, param in enumerate(group['params'], start_index):
for param_id, param in enumerate(group["params"], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
packed_group['params'].append(param_id)
param_info['param2id'][id(param)] = param_id
param_info['id2param'][param_id] = id(param)
param_info['param2shape'][id(param)] = original_shape
packed_group["params"].append(param_id)
param_info["param2id"][id(param)] = param_id
param_info["id2param"][param_id] = id(param)
param_info["param2shape"][id(param)] = original_shape
param_info['param_groups'].append(packed_group)
start_index += len(group['params'])
param_info["param_groups"].append(packed_group)
start_index += len(group["params"])
return param_info
......@@ -147,13 +153,12 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
params = [p for p in group['params'] if p in model_params]
new_param_groups.append({**group, 'params': params})
optim.__setstate__({'param_groups': new_param_groups})
params = [p for p in group["params"] if p in model_params]
new_param_groups.append({**group, "params": params})
optim.__setstate__({"param_groups": new_param_groups})
class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
self.param_info = param_info
if use_pipeline:
......@@ -162,60 +167,87 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(self,
optim: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0):
def __init__(
self,
optim: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
precision: str = "fp16",
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
hysteresis, max_scale, max_norm)
super().__init__(
optim,
precision,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
max_norm,
)
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None,
):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype,
overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group,
forced_dtype)
super().__init__(
optimizer,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
clip_grad_norm,
verbose,
reduce_bucket_size,
communication_dtype,
overlap_communication,
partition_grad,
cpu_offload,
dp_process_group,
tp_process_group,
forced_dtype,
)
class HybridParallelPlugin(PipelinePluginBase):
......@@ -276,46 +308,47 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
"""
def __init__(self,
tp_size: int,
pp_size: int,
precision: str = 'fp16',
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
custom_policy: Policy = None) -> None:
def __init__(
self,
tp_size: int,
pp_size: int,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
custom_policy: Policy = None,
) -> None:
super().__init__()
assert dist.get_world_size() % (
tp_size * pp_size
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
self.tp_size = tp_size
self.pp_size = pp_size
......@@ -334,24 +367,28 @@ class HybridParallelPlugin(PipelinePluginBase):
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
......@@ -362,18 +399,22 @@ class HybridParallelPlugin(PipelinePluginBase):
max_scale=max_scale,
)
self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph)
self.ddp_config = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2))
self.zero_config = dict(
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
)
self.max_norm = max_norm
......@@ -382,10 +423,10 @@ class HybridParallelPlugin(PipelinePluginBase):
return self.pp_size > 1
def supported_devices(self) -> List[str]:
return ['cuda']
return ["cuda"]
def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16', 'fp32']
return ["fp16", "bf16", "fp32"]
def control_device(self) -> bool:
return True
......@@ -410,57 +451,67 @@ class HybridParallelPlugin(PipelinePluginBase):
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
self.ddp_config, self.custom_policy)
model = HybridParallelModule(
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']:
optimizer = HybridParallelAMPOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config)
self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
optimizer.master_to_working_map)
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer.working_to_master_map, optimizer.master_to_working_map
)
else:
optimizer = HybridParallelNaiveOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info)
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**self.amp_config)
self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
optimizer._param_store.master_to_working_param)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
)
return model, optimizer, criterion, dataloader, lr_scheduler
def execute_pipeline(self,
data_iter: Iterator,
model: HybridParallelModule,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
HybridParallelZeroOptimizer]] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
def execute_pipeline(
self,
data_iter: Iterator,
model: HybridParallelModule,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[
Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer]
] = None,
return_loss: bool = True,
return_outputs: bool = False,
) -> dict:
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
# return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
return_outputs)
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad()
......@@ -468,15 +519,9 @@ class HybridParallelPlugin(PipelinePluginBase):
model.sync_grads()
return outputs
def prepare_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
......@@ -499,10 +544,9 @@ class HybridParallelPlugin(PipelinePluginBase):
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset,
num_replicas=self.pg_mesh.size(DP_AXIS),
rank=self.pg_mesh.coordinate(DP_AXIS),
shuffle=shuffle)
sampler = DistributedSampler(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
......@@ -511,14 +555,16 @@ class HybridParallelPlugin(PipelinePluginBase):
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def get_checkpoint_io(self) -> CheckpointIO:
self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
......
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
......@@ -33,7 +31,7 @@ from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
__all__ = ['LowLevelZeroPlugin']
__all__ = ["LowLevelZeroPlugin"]
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
......@@ -42,17 +40,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
return x
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
self.dtype = None
if precision == 'fp16':
if precision == "fp16":
self.dtype = torch.float16
elif precision == 'bf16':
elif precision == "bf16":
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
......@@ -74,7 +71,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
"""Save optimizer to checkpoint but only on master process.
......@@ -91,12 +87,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def save_sharded_optimizer(self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024):
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024,
):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
......@@ -148,9 +146,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
"""Load sharded optimizer with the given path to index file.
......@@ -170,8 +170,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory.')
raise RuntimeError(
f"Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory."
)
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
......@@ -181,9 +183,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
padding_size = (self.coordinator.world_size -
v.numel() % self.coordinator.world_size) % self.coordinator.world_size
if isinstance(v, torch.Tensor) and k != "step":
padding_size = (
self.coordinator.world_size - v.numel() % self.coordinator.world_size
) % self.coordinator.world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
......@@ -194,33 +197,39 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
sharded_optimizer_loading_epilogue(optimizer)
def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
use_safetensors: bool):
def save_unsharded_model(
self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
):
assert isinstance(model, LowLevelZeroModel)
super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
def save_sharded_model(
self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
assert isinstance(model, LowLevelZeroModel)
super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
use_safetensors)
super().save_sharded_model(
model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)
def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict)
model.update_master_params()
def load_sharded_model(self,
model: LowLevelZeroModel,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
def load_sharded_model(
self,
model: LowLevelZeroModel,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
......@@ -264,7 +273,7 @@ class LowLevelZeroPlugin(DPPluginBase):
def __init__(
self,
stage: int = 1,
precision: str = 'fp16',
precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
......@@ -281,9 +290,9 @@ class LowLevelZeroPlugin(DPPluginBase):
verbose: bool = False,
) -> None:
super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
self.stage = stage
self.precision = precision
self.zero_optim_kwargs = dict(
......@@ -319,7 +328,7 @@ class LowLevelZeroPlugin(DPPluginBase):
return True
def supported_devices(self) -> List[str]:
return ['cuda']
return ["cuda"]
def configure(
self,
......@@ -329,15 +338,13 @@ class LowLevelZeroPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
**self.zero_optim_kwargs,
verbose=self.verbose)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer, **self.zero_optim_kwargs, verbose=self.verbose
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
......
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
......@@ -9,11 +9,10 @@ from torch.utils.data import DataLoader, Dataset
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
__all__ = ['Plugin']
__all__ = ["Plugin"]
class Plugin(ABC):
@abstractmethod
def supported_devices(self) -> List[str]:
pass
......@@ -51,33 +50,31 @@ class Plugin(ABC):
"""
Whether the plugin controls the checkpoint io
"""
pass
@abstractmethod
def get_checkpoint_io(self) -> CheckpointIO:
"""
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
"""
pass
@abstractmethod
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""
pass
@abstractmethod
def prepare_dataloader(self,
dataset: Dataset,
batch_size: int,
shuffle: bool = False,
seed: int = 1024,
drop_last: bool = False,
pin_memory: bool = False,
num_workers: int = 0,
**kwargs):
def prepare_dataloader(
self,
dataset: Dataset,
batch_size: int,
shuffle: bool = False,
seed: int = 1024,
drop_last: bool = False,
pin_memory: bool = False,
num_workers: int = 0,
**kwargs,
):
"""Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader`
"""
pass
......@@ -9,13 +9,14 @@ from .plugin_base import Plugin
class PipelinePluginBase(Plugin):
@abstractmethod
def execute_pipeline(self,
data_iter: Iterator,
model: ModelWrapper,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = True,
return_outputs: bool = False) -> dict:
def execute_pipeline(
self,
data_iter: Iterator,
model: ModelWrapper,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = True,
return_outputs: bool = False,
) -> dict:
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment