Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL ...@@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta() AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE: if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
...@@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out ...@@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
for i in range(len(chunk_output)): for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i]))) shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) in ["split", "unbind"]: if get_node_name(chunk_output[i]) in ["split", "unbind"]:
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name, tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
input_node.name) shape_str,
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta']) input_node.name,
input_node.name,
)
tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
tensor_str = "[" + tensor_str[:-2] + "]" tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str) context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else: else:
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str, context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
input_node.name, input_node.name) chunk_output[i].name,
shape_str,
input_node.name,
input_node.name,
)
out_shape = get_node_shape(chunk_output[0]) out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_output_dim[0]] chunk_shape = out_shape[chunk_output_dim[0]]
...@@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out ...@@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out
return context return context
def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node], def _gen_loop_end(
chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str: chunk_inputs: List[Node],
chunk_non_compute_inputs: List[Node],
node_list: List[Node],
chunk_outputs_idx: int,
chunk_outputs_non_tensor: List[Node],
search_chunk: SearchChunk,
) -> str:
""" """
Generate chunk loop end Generate chunk loop end
...@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape( ...@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1: if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0] source_node = meta_node.args[0].args[0]
if (source_node not in chunk_infos[region_idx]["node_chunk_dim"] if (
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None): source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node)) chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice) body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body return body
...@@ -203,11 +229,12 @@ def _add_node_slice( ...@@ -203,11 +229,12 @@ def _add_node_slice(
# outputs node # outputs node
else: else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]): if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", chunk_slice = _gen_chunk_slice_dim(
get_node_shape(chunk_node)) chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
)
if get_node_name(chunk_node) in ["split", "unbind"]: if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = "" split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])): for i in range(len(chunk_node.meta["tensor_meta"])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice) split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2] split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice) body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
...@@ -216,13 +243,15 @@ def _add_node_slice( ...@@ -216,13 +243,15 @@ def _add_node_slice(
return body return body
def emit_code_with_chunk(body: List[str], def emit_code_with_chunk(
nodes: Iterable[Node], body: List[str],
emit_node_func: Callable, nodes: Iterable[Node],
delete_unused_value_func: Callable, emit_node_func: Callable,
search_chunk: SearchChunk, delete_unused_value_func: Callable,
chunk_infos: List, search_chunk: SearchChunk,
eval_mem: bool = False): chunk_infos: List,
eval_mem: bool = False,
):
""" """
Emit code with chunk according to chunk_infos. Emit code with chunk according to chunk_infos.
...@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str], ...@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str],
chunk_ends = [i["region"][1] for i in chunk_infos] chunk_ends = [i["region"][1] for i in chunk_infos]
# chunk inputs # chunk inputs
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs # chunk outputs
...@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str], ...@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
chunk_outputs[region_idx], chunk_outputs[region_idx],
chunk_outputs_dim[region_idx], chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"], chunk_infos[region_idx]["chunk_size"],
)) )
)
if within_chunk_region: if within_chunk_region:
emit_node_func(node, body) emit_node_func(node, body)
...@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str], ...@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
if eval_mem: if eval_mem:
body.append( body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name)) % (node.name)
)
else: else:
emit_node_func(node, body) emit_node_func(node, body)
if node_idx not in chunk_inputs: if node_idx not in chunk_inputs:
...@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str], ...@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
if eval_mem: if eval_mem:
body.append( body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name)) % (node.name)
)
# generate chunk region end # generate chunk region end
if node_idx in chunk_ends: if node_idx in chunk_ends:
body.append( body.append(
_gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list, _gen_loop_end(
chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk)) chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
node_list,
chunk_ends[region_idx],
chunk_outputs_non_tensor[region_idx],
search_chunk,
)
)
within_chunk_region = False within_chunk_region = False
node_idx += 1 node_idx += 1
...@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str], ...@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
if AUTOCHUNK_AVAILABLE: if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen): class AutoChunkCodeGen(CodeGen):
def __init__(
def __init__(self, self,
meta_graph, meta_graph,
max_memory: int = None, max_memory: int = None,
print_mem: bool = False, print_mem: bool = False,
print_progress: bool = False, print_progress: bool = False,
eval_mem: bool = False) -> None: eval_mem: bool = False,
) -> None:
super().__init__() super().__init__()
self.eval_mem = eval_mem self.eval_mem = eval_mem
# find the chunk regions # find the chunk regions
...@@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE: ...@@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE:
Returns: the global name that should be used to reference 'obj' in generated source. Returns: the global name that should be used to reference 'obj' in generated source.
""" """
if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We # HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their # can't import them like normal modules so they must retain their
# fully qualified name. # fully qualified name.
...@@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE: ...@@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE:
return add_global(typename, o) return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg): def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global. # Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"): if isinstance(arg, tuple) and hasattr(arg, "_fields"):
...@@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE: ...@@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE:
# NOTE: we add a variable to distinguish body and ckpt_func # NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body): def emit_node(node: Node, body):
maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}") maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder": if node.op == "placeholder":
assert isinstance(node.target, str) assert isinstance(node.target, str)
maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}") maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "") raw_name = node.target.replace("*", "")
if raw_name != repr(node): if raw_name != repr(node):
...@@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE: ...@@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE:
assert isinstance(node.target, str) assert isinstance(node.target, str)
body.append( body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})") f"({_format_args(node.args[1:], node.kwargs)})"
)
return return
elif node.op == "call_function": elif node.op == "call_function":
assert callable(node.target) assert callable(node.target)
# pretty print operators # pretty print operators
if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods): if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple) assert isinstance(node.args, tuple)
body.append(f"{repr(node)}{maybe_type_annotation} = " body.append(
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}") f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return return
# pretty print inplace operators; required for jit.script to work properly # pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo # not currently supported in normal FX graphs, but generated by torchdynamo
if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods): if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " body.append(
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}") f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
)
return return
qualified_name = _get_qualified_name(node.target) qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target) global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument # special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str) if (
and node.args[1].isidentifier() and len(node.args) == 2): global_name == "getattr"
and isinstance(node.args, tuple)
and isinstance(node.args[1], str)
and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append( body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}") f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return return
body.append( body.append(
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})") f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
)
if node.meta.get("is_wrapped", False): if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name) wrapped_fns.setdefault(global_name)
return return
elif node.op == "call_module": elif node.op == "call_module":
assert isinstance(node.target, str) assert isinstance(node.target, str)
body.append(f"{repr(node)}{maybe_type_annotation} = " body.append(
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})") f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return return
elif node.op == "get_attr": elif node.op == "get_attr":
assert isinstance(node.target, str) assert isinstance(node.target, str)
...@@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE: ...@@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we # if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen # will use nested type of activation checkpoint codegen
emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, emit_code_with_chunk(
self.eval_mem) body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
)
if len(body) == 0: if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body # If the Graph has no non-placeholder nodes, no lines for the body
......
import copy from typing import Dict, List
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch import torch
from torch.fx.node import Node from torch.fx.node import Node
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import NodeMgr, get_node_shape, is_non_memory_node from .utils import NodeMgr, get_node_shape, is_non_memory_node
...@@ -62,12 +59,9 @@ class EstimateMemory(object): ...@@ -62,12 +59,9 @@ class EstimateMemory(object):
delete_node_dict[node] = max(node_user_idx) delete_node_dict[node] = max(node_user_idx)
return delete_node_dict return delete_node_dict
def _remove_deactive_node(self, def _remove_deactive_node(
user_idx: int, self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None
user: Node, ) -> None:
active_nodes: List,
delete_node_dict: List,
kept_nodes: List = None) -> None:
""" """
remove deactivate nodes from active nodes remove deactivate nodes from active nodes
""" """
...@@ -169,7 +163,7 @@ class EstimateMemory(object): ...@@ -169,7 +163,7 @@ class EstimateMemory(object):
use_chunk = True if chunk_infos is not None else False use_chunk = True if chunk_infos is not None else False
chunk_within = False chunk_within = False
chunk_region_idx = None chunk_region_idx = None
chunk_ratio = 1 # use it to estimate chunk mem chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_all = [] chunk_inputs_all = []
if use_chunk: if use_chunk:
...@@ -184,7 +178,6 @@ class EstimateMemory(object): ...@@ -184,7 +178,6 @@ class EstimateMemory(object):
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos] chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_mgr.get_node_list()): for idx, node in enumerate(node_mgr.get_node_list()):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor # if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts: if use_chunk and idx in chunk_starts:
chunk_within = True chunk_within = True
...@@ -193,8 +186,9 @@ class EstimateMemory(object): ...@@ -193,8 +186,9 @@ class EstimateMemory(object):
# determine chunk ratio for current node # determine chunk ratio for current node
if chunk_within: if chunk_within:
chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx], chunk_ratio = self._get_chunk_ratio(
chunk_sizes[chunk_region_idx]) node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
)
# add current node as active node # add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio) self._add_active_node(node, active_nodes, chunk_ratio)
...@@ -222,7 +216,7 @@ class EstimateMemory(object): ...@@ -222,7 +216,7 @@ class EstimateMemory(object):
# if node in chunk end nodes, restore chunk settings # if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends: if use_chunk and idx in chunk_ends:
self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
chunk_within = False chunk_within = False
chunk_ratio = 1 chunk_ratio = 1
chunk_region_idx = None chunk_region_idx = None
......
...@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph ...@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk from .select_chunk import SelectChunk
from .trace_flow import TraceFlow from .trace_flow import TraceFlow
from .trace_indice import TraceIndice from .trace_indice import TraceIndice
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object): class SearchChunk(object):
...@@ -121,8 +121,10 @@ class SearchChunk(object): ...@@ -121,8 +121,10 @@ class SearchChunk(object):
# check if peak node already in chunk info # check if peak node already in chunk info
if chunk_regions is not None: if chunk_regions is not None:
for i in chunk_regions: for i in chunk_regions:
if i["region"][0] < peak_region[0] <= i["region"][1] or \ if (
i["region"][0] < peak_region[1] <= i["region"][1]: i["region"][0] < peak_region[0] <= i["region"][1]
or i["region"][0] < peak_region[1] <= i["region"][1]
):
return None return None
active_node_num = [len(i) for i in active_node] active_node_num = [len(i) for i in active_node]
...@@ -146,9 +148,9 @@ class SearchChunk(object): ...@@ -146,9 +148,9 @@ class SearchChunk(object):
region = i["region"] region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]: if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]): elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
chunk_region_start = region[1] + 1 chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]): elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
chunk_region_end = region[0] - 1 chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end return chunk_region_start, chunk_region_end
...@@ -171,7 +173,7 @@ class SearchChunk(object): ...@@ -171,7 +173,7 @@ class SearchChunk(object):
chunk_infos: possible regions found chunk_infos: possible regions found
""" """
start_traces = input_trace[start_idx] start_traces = input_trace[start_idx]
if len(start_traces) > 1: # TODO need to be removed if len(start_traces) > 1: # TODO need to be removed
return [] return []
end_trace = output_trace[end_idx] end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx) end_node = self.node_mgr.get_node_by_idx(end_idx)
...@@ -180,8 +182,9 @@ class SearchChunk(object): ...@@ -180,8 +182,9 @@ class SearchChunk(object):
for end_dim, _ in enumerate(end_trace["indice"]): for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items(): for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]): for start_dim, _ in enumerate(start_trace["indice"]):
if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, if not self.trace_flow.check_region_start_end(
end_idx): start_node, start_dim, start_idx, end_node, end_dim, end_idx
):
continue continue
# flow search # flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim) chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
...@@ -203,7 +206,7 @@ class SearchChunk(object): ...@@ -203,7 +206,7 @@ class SearchChunk(object):
""" """
possible_chunk_region = [] possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list) output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
input_trace = [] # trace of a node's input nodes input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_mgr.get_node_list()): for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {} cur_trace = {}
for arg in n.args: for arg in n.args:
...@@ -215,7 +218,8 @@ class SearchChunk(object): ...@@ -215,7 +218,8 @@ class SearchChunk(object):
for end_idx in range(peak_region[1], max_chunk_region[1] + 1): for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes # skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node( if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.node_mgr.get_node_by_idx(end_idx)): self.node_mgr.get_node_by_idx(end_idx)
):
continue continue
# select free dim # select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx) chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
...@@ -279,15 +283,18 @@ class SearchChunk(object): ...@@ -279,15 +283,18 @@ class SearchChunk(object):
chunk_infos.append(chunk_info) chunk_infos.append(chunk_info)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem( mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos) self.node_mgr.get_node_list(), chunk_infos
)
if self.print_progress: if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" % get_logger().info(
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])) "AutoChunk find chunk region %d = (%d, %d)"
% (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])
)
if self.print_mem: if self.print_mem:
self.print_mem = False self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), self.estimate_memory.estimate_chunk_inference_mem(
chunk_infos, self.node_mgr.get_node_list(), chunk_infos, print_mem=True
print_mem=True) )
return chunk_infos return chunk_infos
...@@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node ...@@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object): class SelectChunk(object):
def __init__( def __init__(
self, self,
trace_indice: TraceIndice, trace_indice: TraceIndice,
...@@ -20,7 +19,7 @@ class SelectChunk(object): ...@@ -20,7 +19,7 @@ class SelectChunk(object):
self.node_mgr = node_mgr self.node_mgr = node_mgr
if max_memory is not None: if max_memory is not None:
self.stratge = "fit_memory" self.stratge = "fit_memory"
self.max_memory = max_memory # MB self.max_memory = max_memory # MB
else: else:
self.stratge = "min_memory" self.stratge = "min_memory"
...@@ -57,16 +56,18 @@ class SelectChunk(object): ...@@ -57,16 +56,18 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region] cur_chunk_infos = chunk_infos + [cur_region]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1] cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak) cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory: if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append({ regions_dict.append(
"chunk_info": region, {
"chunk_max_mem": cur_chunk_region_max_peak, "chunk_info": region,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), "chunk_max_mem": cur_chunk_region_max_peak,
"reorder_chunk_info": cur_region, "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_node_list": cur_node_list, "reorder_chunk_info": cur_region,
}) "reorder_node_list": cur_node_list,
}
)
# no region found # no region found
if len(regions_dict) == 0: if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.") raise RuntimeError("Search failed. Try a larger memory threshold.")
...@@ -90,13 +91,15 @@ class SelectChunk(object): ...@@ -90,13 +91,15 @@ class SelectChunk(object):
chunk_size *= 2 chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info] cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_chunk_infos)[0] chunk_region_dict["reorder_node_list"], cur_chunk_infos
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1]) )[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
# search exact size # search exact size
chunk_info = chunk_region_dict["chunk_info"] chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict, chunk_info["chunk_size"] = self._chunk_size_binary_search(
chunk_infos) chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
)
return chunk_info return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
...@@ -109,9 +112,10 @@ class SelectChunk(object): ...@@ -109,9 +112,10 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5) mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info] cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_chunk_infos)[0] chunk_region_dict["reorder_node_list"], cur_chunk_infos
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1]) )[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory: if cur_chunk_max_mem >= self.max_memory:
right = mid - gap right = mid - gap
else: else:
...@@ -139,8 +143,10 @@ class SelectChunk(object): ...@@ -139,8 +143,10 @@ class SelectChunk(object):
return None return None
# get max possible chunk region # get max possible chunk region
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]), max_possible_chunk_region = (
max([i["region"][1] for i in possible_chunk_regions])) min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]),
)
# get mem for chunk region # get mem for chunk region
regions_dict_list = [] regions_dict_list = []
...@@ -149,15 +155,17 @@ class SelectChunk(object): ...@@ -149,15 +155,17 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region] cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak) cur_chunk_region_max_peak = max(cur_chunk_region_peak)
regions_dict_list.append({ regions_dict_list.append(
"chunk_info": region, {
"chunk_max_mem": cur_chunk_region_max_peak, "chunk_info": region,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), "chunk_max_mem": cur_chunk_region_max_peak,
"reorder_chunk_info": cur_region, "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_node_list": cur_node_list, "reorder_chunk_info": cur_region,
}) "reorder_node_list": cur_node_list,
}
)
# select the min mem # select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list] chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
...@@ -175,7 +183,9 @@ class SelectChunk(object): ...@@ -175,7 +183,9 @@ class SelectChunk(object):
return False return False
for i in chunk_infos: for i in chunk_infos:
region = i["region"] region = i["region"]
if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or if not (
(chunk_region_start < region[0] and chunk_region_end < region[0])): (chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
return False return False
return True return True
...@@ -16,7 +16,6 @@ from .utils import ( ...@@ -16,7 +16,6 @@ from .utils import (
class TraceFlow(object): class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None: def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice self.trace_indice = trace_indice
self.node_mgr = node_mgr self.node_mgr = node_mgr
...@@ -151,7 +150,7 @@ class TraceFlow(object): ...@@ -151,7 +150,7 @@ class TraceFlow(object):
return True return True
def _get_all_node_info(self, end_dim, start_idx, end_idx): def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0: while len(cur_node_list) > 0:
...@@ -266,7 +265,7 @@ class TraceFlow(object): ...@@ -266,7 +265,7 @@ class TraceFlow(object):
maybe_prepose_nodes.sort( maybe_prepose_nodes.sort(
key=lambda x: self.node_mgr.find_node_idx(x), key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True, reverse=True,
) # from last node to first node ) # from last node to first node
prepose_nodes = [] prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes # set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0: while len(maybe_prepose_nodes) > 0:
...@@ -328,7 +327,8 @@ class TraceFlow(object): ...@@ -328,7 +327,8 @@ class TraceFlow(object):
def flow_search(self, start_idx, start_dim, end_idx, end_dim): def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes( inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)) self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
)
# get every node's chunk dim and fix dim # get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
...@@ -371,8 +371,9 @@ class TraceFlow(object): ...@@ -371,8 +371,9 @@ class TraceFlow(object):
return chunk_info return chunk_info
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, def _get_other_output_info(
chunk_info: Dict): self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
):
start_node = self.node_mgr.get_node_by_idx(start_idx) start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs # loop all outputs
for output in outputs: for output in outputs:
...@@ -384,8 +385,8 @@ class TraceFlow(object): ...@@ -384,8 +385,8 @@ class TraceFlow(object):
# skip non tensor # skip non tensor
if get_node_shape(output) is None: if get_node_shape(output) is None:
# log shape tensor # log shape tensor
if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int): if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out']) chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
continue continue
# loop every dim of outputs, try to find a legal one # loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))): for output_dim in range(len(get_node_shape(output))):
...@@ -421,7 +422,8 @@ class TraceFlow(object): ...@@ -421,7 +422,8 @@ class TraceFlow(object):
for k, v in new_all_node_info.items(): for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]: if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list( chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])) set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
)
else: else:
chunk_info["node_chunk_dim"][k] = v chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output) chunk_info["outputs"].append(output)
...@@ -443,8 +445,11 @@ class TraceFlow(object): ...@@ -443,8 +445,11 @@ class TraceFlow(object):
if node.args[0] in chunk_info["inputs_non_chunk"]: if node.args[0] in chunk_info["inputs_non_chunk"]:
continue continue
reshape_args = flat_list(node.args[1:]) reshape_args = flat_list(node.args[1:])
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len( if (
reshape_args[0].meta['fwd_out']) > 1: len(reshape_args) == 1
and get_node_shape(reshape_args[0]) is None
and len(reshape_args[0].meta["fwd_out"]) > 1
):
continue continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = "" new_shape = ""
...@@ -462,16 +467,17 @@ class TraceFlow(object): ...@@ -462,16 +467,17 @@ class TraceFlow(object):
chunk_info["reshape_size"] = reshape_size chunk_info["reshape_size"] = reshape_size
return chunk_info return chunk_info
def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, def check_region_start_end(
end_idx: int) -> bool: self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
) -> bool:
""" """
check if region start and end is legal check if region start and end is legal
""" """
# dim cannot be None # dim cannot be None
if (get_node_shape(end_node) is None or get_node_shape(start_node) is None): if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
return False return False
# dim size cannot be 1 # dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1): if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
return False return False
# must have users # must have users
if len(end_node.users) == 0: if len(end_node.users) == 0:
......
import copy import copy
from typing import Dict, List, Tuple from typing import Dict, List
from torch.fx.node import Node from torch.fx.node import Node
...@@ -412,7 +412,7 @@ class TraceIndice(object): ...@@ -412,7 +412,7 @@ class TraceIndice(object):
node_idx (int) node_idx (int)
""" """
# get conv input # get conv input
assert node.kwargs['size'] is None assert node.kwargs["size"] is None
assert len(get_node_shape(node)) == 4 assert len(get_node_shape(node)) == 4
# assign index # assign index
...@@ -826,7 +826,7 @@ class TraceIndice(object): ...@@ -826,7 +826,7 @@ class TraceIndice(object):
# clear compute # clear compute
for dim_compute in trace["compute"]: for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1): for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes): if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes:
dim_compute.pop(i) dim_compute.pop(i)
continue continue
# clear source # clear source
...@@ -876,10 +876,24 @@ class TraceIndice(object): ...@@ -876,10 +876,24 @@ class TraceIndice(object):
self._assign_matmul_indice(node, idx) self._assign_matmul_indice(node, idx)
elif "softmax" == node_name: elif "softmax" == node_name:
self._assign_softmax_indice(node, idx) self._assign_softmax_indice(node, idx)
elif any(n == node_name for n in [ elif any(
"mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp", n == node_name
"sin", "cos" for n in [
]): "mul",
"add",
"sigmoid",
"relu",
"sub",
"truediv",
"pow",
"dropout",
"where",
"tanh",
"exp",
"sin",
"cos",
]
):
self._assign_elementwise_indice(node, idx) self._assign_elementwise_indice(node, idx)
elif "einsum" == node_name: elif "einsum" == node_name:
self._assign_einsum_indice(node, idx) self._assign_einsum_indice(node, idx)
...@@ -920,7 +934,7 @@ class TraceIndice(object): ...@@ -920,7 +934,7 @@ class TraceIndice(object):
else: else:
raise NotImplementedError(node_name, "module not implemented yet!") raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr": elif node.op == "get_attr":
self._assign_all_indice(node, idx) # get param self._assign_all_indice(node, idx) # get param
elif node.op == "output": elif node.op == "output":
continue continue
else: else:
......
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union from typing import Any, Dict, List, Union
from torch.fx.node import Node from torch.fx.node import Node
...@@ -10,7 +10,6 @@ logger = get_dist_logger() ...@@ -10,7 +10,6 @@ logger = get_dist_logger()
class NodeMgr(object): class NodeMgr(object):
def __init__(self, nodes_list: List[Node]) -> None: def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list self._node_list = nodes_list
self._node_dict = {} self._node_dict = {}
...@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, ...@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
# we treat that input node as the input of the checkpoint function # we treat that input node as the input of the checkpoint function
for node in nodes: for node in nodes:
for input_node in node._input_nodes.keys(): for input_node in node._input_nodes.keys():
if (input_node not in nodes and input_node not in input_nodes if (
and not is_non_compute_node_except_placeholder(input_node)): input_node not in nodes
and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)
):
input_nodes.append(input_node) input_nodes.append(input_node)
# if a node has a user node which is not in the node list # if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output # we treat that user node as the node receiving the current node output
for node in nodes: for node in nodes:
for output_node in node.users.keys(): for output_node in node.users.keys():
if (output_node not in nodes and node not in output_nodes if (
and not is_non_compute_node_except_placeholder_output(output_node)): output_node not in nodes
and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)
):
output_nodes.append(node) output_nodes.append(node)
return input_nodes, output_nodes return input_nodes, output_nodes
...@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]: ...@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
for node in node_list: for node in node_list:
if get_node_shape(node) is not None: if get_node_shape(node) is not None:
out.append(node) out.append(node)
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance( elif (
node.meta['fwd_out'][0], int): len(node.meta["fwd_out"]) > 0
and isinstance(node.meta["fwd_out"], list)
and isinstance(node.meta["fwd_out"][0], int)
):
out.append(node) out.append(node)
return out return out
import torch import torch
import torch.nn as nn import torch.nn as nn
__all__ = ['Accelerator'] __all__ = ["Accelerator"]
_supported_devices = [ _supported_devices = [
'cpu', "cpu",
'cuda', "cuda",
# To be supported # To be supported
# 'xpu', # 'xpu',
# 'npu', # 'npu',
...@@ -25,21 +24,22 @@ class Accelerator: ...@@ -25,21 +24,22 @@ class Accelerator:
def __init__(self, device: str): def __init__(self, device: str):
self.device = device self.device = device
assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}" assert (
self.device in _supported_devices
), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
def bind(self): def bind(self):
""" """
Set the default device for the current process. Set the default device for the current process.
""" """
if self.device == 'cpu': if self.device == "cpu":
pass pass
elif self.device == 'cuda': elif self.device == "cuda":
# TODO(FrankLeeeee): use global environment to check if it is a dist job # TODO(FrankLeeeee): use global environment to check if it is a dist job
# if is_distributed: # if is_distributed:
# local_rank = EnvTable().get_local_rank() # local_rank = EnvTable().get_local_rank()
# torch.cuda.set_device(torch.device(f'cuda:{local_rank}')) # torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
torch.cuda.set_device(torch.device('cuda')) torch.cuda.set_device(torch.device("cuda"))
pass
else: else:
raise ValueError(f"Device {self.device} is not supported yet") raise ValueError(f"Device {self.device} is not supported yet")
......
...@@ -16,7 +16,7 @@ from .mixed_precision import MixedPrecision, mixed_precision_factory ...@@ -16,7 +16,7 @@ from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase from .plugin.pp_plugin_base import PipelinePluginBase
__all__ = ['Booster'] __all__ = ["Booster"]
class Booster: class Booster:
...@@ -60,28 +60,31 @@ class Booster: ...@@ -60,28 +60,31 @@ class Booster:
plugin (Plugin): The plugin to run the training. Default: None. plugin (Plugin): The plugin to run the training. Default: None.
""" """
def __init__(self, def __init__(
device: Optional[str] = None, self,
mixed_precision: Optional[Union[MixedPrecision, str]] = None, device: Optional[str] = None,
plugin: Optional[Plugin] = None) -> None: mixed_precision: Optional[Union[MixedPrecision, str]] = None,
plugin: Optional[Plugin] = None,
) -> None:
if plugin is not None: if plugin is not None:
assert isinstance( assert isinstance(
plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.' plugin, Plugin
), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
self.plugin = plugin self.plugin = plugin
# set accelerator # set accelerator
if self.plugin and self.plugin.control_device(): if self.plugin and self.plugin.control_device():
self.accelerator = None self.accelerator = None
if device is not None: if device is not None:
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
else: else:
device = device or 'cuda' device = device or "cuda"
self.accelerator = Accelerator(device) self.accelerator = Accelerator(device)
# set precision # set precision
if self.plugin and self.plugin.control_precision(): if self.plugin and self.plugin.control_precision():
if mixed_precision is not None: if mixed_precision is not None:
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
self.mixed_precision = None self.mixed_precision = None
elif mixed_precision is None: elif mixed_precision is None:
self.mixed_precision = None self.mixed_precision = None
...@@ -95,7 +98,7 @@ class Booster: ...@@ -95,7 +98,7 @@ class Booster:
self.mixed_precision = mixed_precision self.mixed_precision = mixed_precision
else: else:
raise ValueError( raise ValueError(
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.' f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}."
) )
if self.plugin is not None and self.plugin.control_checkpoint_io(): if self.plugin is not None and self.plugin.control_checkpoint_io():
...@@ -131,7 +134,8 @@ class Booster: ...@@ -131,7 +134,8 @@ class Booster:
# transform model for mixed precision # transform model for mixed precision
if self.plugin: if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure( model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
model, optimizer, criterion, dataloader, lr_scheduler) model, optimizer, criterion, dataloader, lr_scheduler
)
if self.plugin and not self.plugin.control_device(): if self.plugin and not self.plugin.control_device():
# transform model for accelerator # transform model for accelerator
...@@ -154,13 +158,15 @@ class Booster: ...@@ -154,13 +158,15 @@ class Booster:
# TODO(frank lee): implement this method with plugin # TODO(frank lee): implement this method with plugin
optimizer.backward(loss) optimizer.backward(loss)
def execute_pipeline(self, def execute_pipeline(
data_iter: Iterator, self,
model: nn.Module, data_iter: Iterator,
criterion: Callable[[Any, Any], torch.Tensor], model: nn.Module,
optimizer: Optional[Optimizer] = None, criterion: Callable[[Any, Any], torch.Tensor],
return_loss: bool = True, optimizer: Optional[Optimizer] = None,
return_outputs: bool = False) -> Dict[str, Any]: return_loss: bool = True,
return_outputs: bool = False,
) -> Dict[str, Any]:
""" """
Execute forward & backward when utilizing pipeline parallel. Execute forward & backward when utilizing pipeline parallel.
Return loss or Huggingface style model outputs if needed. Return loss or Huggingface style model outputs if needed.
...@@ -185,8 +191,9 @@ class Booster: ...@@ -185,8 +191,9 @@ class Booster:
ret_dict['loss'] is the loss of forward if return_loss is set to True, else None. ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None. ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None.
""" """
assert isinstance(self.plugin, assert isinstance(
PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.' self.plugin, PipelinePluginBase
), f"The plugin {self.plugin.__class__.__name__} does not support pipeline."
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs) return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager: def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
...@@ -200,8 +207,10 @@ class Booster: ...@@ -200,8 +207,10 @@ class Booster:
Returns: Returns:
contextmanager: Context to disable gradient synchronization. contextmanager: Context to disable gradient synchronization.
""" """
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' assert (
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' self.plugin is not None
), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync."
assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
return self.plugin.no_sync(model, optimizer) return self.plugin.no_sync(model, optimizer)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
...@@ -217,14 +226,16 @@ class Booster: ...@@ -217,14 +226,16 @@ class Booster:
""" """
self.checkpoint_io.load_model(model, checkpoint, strict) self.checkpoint_io.load_model(model, checkpoint, strict)
def save_model(self, def save_model(
model: Union[nn.Module, ModelWrapper], self,
checkpoint: str, model: Union[nn.Module, ModelWrapper],
shard: bool = False, checkpoint: str,
gather_dtensor: bool = True, shard: bool = False,
prefix: Optional[str] = None, gather_dtensor: bool = True,
size_per_shard: int = 1024, prefix: Optional[str] = None,
use_safetensors: bool = False) -> None: size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""Save model to checkpoint. """Save model to checkpoint.
Args: Args:
...@@ -239,13 +250,15 @@ class Booster: ...@@ -239,13 +250,15 @@ class Booster:
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved. use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
""" """
self.checkpoint_io.save_model(model, self.checkpoint_io.save_model(
checkpoint=checkpoint, model,
shard=shard, checkpoint=checkpoint,
gather_dtensor=gather_dtensor, shard=shard,
prefix=prefix, gather_dtensor=gather_dtensor,
size_per_shard=size_per_shard, prefix=prefix,
use_safetensors=use_safetensors) size_per_shard=size_per_shard,
use_safetensors=use_safetensors,
)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""Load optimizer from checkpoint. """Load optimizer from checkpoint.
...@@ -260,13 +273,15 @@ class Booster: ...@@ -260,13 +273,15 @@ class Booster:
""" """
self.checkpoint_io.load_optimizer(optimizer, checkpoint) self.checkpoint_io.load_optimizer(optimizer, checkpoint)
def save_optimizer(self, def save_optimizer(
optimizer: Optimizer, self,
checkpoint: str, optimizer: Optimizer,
shard: bool = False, checkpoint: str,
gather_dtensor: bool = True, shard: bool = False,
prefix: Optional[str] = None, gather_dtensor: bool = True,
size_per_shard: int = 1024) -> None: prefix: Optional[str] = None,
size_per_shard: int = 1024,
) -> None:
""" """
Save optimizer to checkpoint. Save optimizer to checkpoint.
......
...@@ -6,16 +6,22 @@ from .fp16_torch import FP16TorchMixedPrecision ...@@ -6,16 +6,22 @@ from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision from .mixed_precision_base import MixedPrecision
__all__ = [ __all__ = [
'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision', "MixedPrecision",
'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision' "mixed_precision_factory",
"FP16_Apex_MixedPrecision",
"FP16_Torch_MixedPrecision",
"FP32_MixedPrecision",
"BF16_MixedPrecision",
"FP8_MixedPrecision",
"FP16NaiveMixedPrecision",
] ]
_mixed_precision_mapping = { _mixed_precision_mapping = {
'fp16': FP16TorchMixedPrecision, "fp16": FP16TorchMixedPrecision,
'fp16_apex': FP16ApexMixedPrecision, "fp16_apex": FP16ApexMixedPrecision,
'fp16_naive': FP16NaiveMixedPrecision, "fp16_naive": FP16NaiveMixedPrecision,
'bf16': BF16MixedPrecision, "bf16": BF16MixedPrecision,
'fp8': FP8MixedPrecision "fp8": FP8MixedPrecision,
} }
...@@ -31,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision: ...@@ -31,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
return _mixed_precision_mapping[mixed_precision_type]() return _mixed_precision_mapping[mixed_precision_type]()
else: else:
raise ValueError( raise ValueError(
f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}' f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}"
) )
...@@ -23,16 +23,18 @@ class FP16ApexMixedPrecision(MixedPrecision): ...@@ -23,16 +23,18 @@ class FP16ApexMixedPrecision(MixedPrecision):
max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored. max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
""" """
def __init__(self, def __init__(
opt_level: Optional[str] = "O1", self,
cast_model_type: torch.dtype = None, opt_level: Optional[str] = "O1",
patch_torch_functions: bool = None, cast_model_type: torch.dtype = None,
keep_batchnorm_fp32: Union[bool, str] = None, patch_torch_functions: bool = None,
master_weights: bool = None, keep_batchnorm_fp32: Union[bool, str] = None,
loss_scale: Union[float, str] = None, master_weights: bool = None,
cast_model_outputs: Any = None, loss_scale: Union[float, str] = None,
num_losses: Optional[int] = 1, cast_model_outputs: Any = None,
verbosity: int = 1, num_losses: Optional[int] = 1,
min_loss_scale: float = None, verbosity: int = 1,
max_loss_scale: float = 2.**24) -> None: min_loss_scale: float = None,
max_loss_scale: float = 2.0**24,
) -> None:
pass pass
...@@ -15,12 +15,14 @@ class FP16NaiveMixedPrecision(MixedPrecision): ...@@ -15,12 +15,14 @@ class FP16NaiveMixedPrecision(MixedPrecision):
verbose(bool): if set to `True`, will print debug info. verbose(bool): if set to `True`, will print debug info.
""" """
def __init__(self, def __init__(
log_num_zeros_in_grad: bool, self,
initial_scale: int, log_num_zeros_in_grad: bool,
growth_factor: int, initial_scale: int,
backoff_factor: float, growth_factor: int,
hysteresis: int, backoff_factor: float,
max_scale: int, hysteresis: int,
verbose: bool = None) -> None: max_scale: int,
verbose: bool = None,
) -> None:
pass pass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper ...@@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .mixed_precision_base import MixedPrecision from .mixed_precision_base import MixedPrecision
__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule'] __all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
class TorchAMPOptimizer(OptimizerWrapper): class TorchAMPOptimizer(OptimizerWrapper):
...@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper): ...@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper):
calls that may cause the scale to increase. Default: 2000. calls that may cause the scale to increase. Default: 2000.
""" """
def __init__(self, def __init__(
optim: Optimizer, self,
init_scale: float = 2.**16, optim: Optimizer,
growth_factor: float = 2.0, init_scale: float = 2.0**16,
backoff_factor: float = 0.5, growth_factor: float = 2.0,
growth_interval: int = 2000) -> None: backoff_factor: float = 0.5,
growth_interval: int = 2000,
) -> None:
super().__init__(optim) super().__init__(optim)
self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, self.scaler = torch.cuda.amp.GradScaler(
growth_factor=growth_factor, init_scale=init_scale,
backoff_factor=backoff_factor, growth_factor=growth_factor,
growth_interval=growth_interval) backoff_factor=backoff_factor,
growth_interval=growth_interval,
)
def backward(self, loss: Tensor, *args, **kwargs) -> None: def backward(self, loss: Tensor, *args, **kwargs) -> None:
scaled_loss = self.scale_loss(loss) scaled_loss = self.scale_loss(loss)
...@@ -60,12 +64,14 @@ class TorchAMPOptimizer(OptimizerWrapper): ...@@ -60,12 +64,14 @@ class TorchAMPOptimizer(OptimizerWrapper):
self.unscale_grad() self.unscale_grad()
super().clip_grad_by_value(clip_value, *args, **kwargs) super().clip_grad_by_value(clip_value, *args, **kwargs)
def clip_grad_by_norm(self, def clip_grad_by_norm(
max_norm: Union[float, int], self,
norm_type: Union[float, int] = 2.0, max_norm: Union[float, int],
error_if_nonfinite: bool = False, norm_type: Union[float, int] = 2.0,
*args, error_if_nonfinite: bool = False,
**kwargs) -> None: *args,
**kwargs,
) -> None:
self.unscale_grad() self.unscale_grad()
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs) super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
...@@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision): ...@@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision):
calls that may cause the scale to increase. Default: 2000. calls that may cause the scale to increase. Default: 2000.
""" """
def __init__(self, def __init__(
init_scale: float = 2.**16, self,
growth_factor: float = 2.0, init_scale: float = 2.0**16,
backoff_factor: float = 0.5, growth_factor: float = 2.0,
growth_interval: int = 2000) -> None: backoff_factor: float = 0.5,
growth_interval: int = 2000,
) -> None:
super().__init__() super().__init__()
self.torch_amp_kwargs = dict(init_scale=init_scale, self.torch_amp_kwargs = dict(
growth_factor=growth_factor, init_scale=init_scale,
backoff_factor=backoff_factor, growth_factor=growth_factor,
growth_interval=growth_interval) backoff_factor=backoff_factor,
growth_interval=growth_interval,
def configure(self, )
model: nn.Module,
optimizer: Optional[Optimizer] = None, def configure(
criterion: Optional[Callable] = None, self,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]: model: nn.Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model) model = TorchAMPModule(model)
if optimizer is not None: if optimizer is not None:
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
......
...@@ -4,11 +4,12 @@ from .low_level_zero_plugin import LowLevelZeroPlugin ...@@ -4,11 +4,12 @@ from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin'] __all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
import torch import torch
from packaging import version from packaging import version
if version.parse(torch.__version__) >= version.parse('1.12.0'): if version.parse(torch.__version__) >= version.parse("1.12.0"):
from .torch_fsdp_plugin import TorchFSDPPlugin from .torch_fsdp_plugin import TorchFSDPPlugin
__all__.append('TorchFSDPPlugin')
__all__.append("TorchFSDPPlugin")
...@@ -10,25 +10,19 @@ from .plugin_base import Plugin ...@@ -10,25 +10,19 @@ from .plugin_base import Plugin
class DPPluginBase(Plugin): class DPPluginBase(Plugin):
"""This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation. """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation."""
"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
assert dist.is_initialized( assert (
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' dist.is_initialized()
), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment"
self.rank = dist.get_rank() self.rank = dist.get_rank()
self.world_size = dist.get_world_size() self.world_size = dist.get_world_size()
def prepare_dataloader(self, def prepare_dataloader(
dataset, self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
batch_size, ):
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r""" r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
...@@ -60,11 +54,13 @@ class DPPluginBase(Plugin): ...@@ -60,11 +54,13 @@ class DPPluginBase(Plugin):
torch.manual_seed(worker_seed) torch.manual_seed(worker_seed)
random.seed(worker_seed) random.seed(worker_seed)
return DataLoader(dataset, return DataLoader(
batch_size=batch_size, dataset,
sampler=sampler, batch_size=batch_size,
worker_init_fn=seed_worker, sampler=sampler,
drop_last=drop_last, worker_init_fn=seed_worker,
pin_memory=pin_memory, drop_last=drop_last,
num_workers=num_workers, pin_memory=pin_memory,
**_kwargs) num_workers=num_workers,
**_kwargs,
)
...@@ -27,14 +27,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats ...@@ -27,14 +27,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats
from .dp_plugin_base import DPPluginBase from .dp_plugin_base import DPPluginBase
__all__ = ['GeminiPlugin'] __all__ = ["GeminiPlugin"]
SUPPORTED_PRECISION = ['fp16', 'bf16'] SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16} PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
class GeminiCheckpointIO(GeneralCheckpointIO): class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.coordinator = DistCoordinator() self.coordinator = DistCoordinator()
...@@ -74,13 +73,15 @@ class GeminiCheckpointIO(GeneralCheckpointIO): ...@@ -74,13 +73,15 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
""" """
super().load_unsharded_optimizer(optimizer, checkpoint) super().load_unsharded_optimizer(optimizer, checkpoint)
def save_sharded_model(self, def save_sharded_model(
model: GeminiDDP, self,
checkpoint_path: str, model: GeminiDDP,
gather_dtensor: bool = False, checkpoint_path: str,
prefix: Optional[str] = None, gather_dtensor: bool = False,
max_shard_size: int = 1024, prefix: Optional[str] = None,
use_safetensors: bool = False): max_shard_size: int = 1024,
use_safetensors: bool = False,
):
""" """
Save sharded model. Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes. As there is communication when getting state dict, model.state_dict() must be called on all processes.
...@@ -97,34 +98,37 @@ class GeminiCheckpointIO(GeneralCheckpointIO): ...@@ -97,34 +98,37 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Save shards of optimizer states. # Save shards of optimizer states.
is_master = self.coordinator.is_master() is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, total_size = save_state_dict_shards(
checkpoint=checkpoint_path, sharded_state_dict=state_dict_shard,
index_file=index_file, checkpoint=checkpoint_path,
base_filename=weights_name, index_file=index_file,
is_master=is_master, base_filename=weights_name,
use_safetensors=use_safetensors) is_master=is_master,
use_safetensors=use_safetensors,
)
# only save the index file on the master rank # only save the index file on the master rank
if self.coordinator.is_master(): if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
save_config_file(model.module, checkpoint_path) save_config_file(model.module, checkpoint_path)
logging.info(f"The model is split into checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The model is split into checkpoint shards. "
f"index located at {save_index_file}.") f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
def load_sharded_model(self, )
model: GeminiDDP,
checkpoint_index_file: Path, def load_sharded_model(
strict: bool = False, self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
use_safetensors: bool = False): ):
""" """
Load shard model, load model from multiple files. Load shard model, load model from multiple files.
""" """
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, def save_sharded_optimizer(
size_per_shard: int): self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
""" """
Save sharded optimizer state dict to checkpoint folder. Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes. As there is communication when getting state dict, this must be called on all processes.
...@@ -153,20 +157,24 @@ class GeminiCheckpointIO(GeneralCheckpointIO): ...@@ -153,20 +157,24 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Save shards of optimizer states. # Save shards of optimizer states.
is_master = self.coordinator.is_master() is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, total_size = save_state_dict_shards(
checkpoint=checkpoint, sharded_state_dict=state_dict_shard,
index_file=index_file, checkpoint=checkpoint,
base_filename=states_name, index_file=index_file,
is_master=is_master, base_filename=states_name,
use_safetensors=False) is_master=is_master,
use_safetensors=False,
)
# Wrap up index file. Only save it on master rank. # Wrap up index file. Only save it on master rank.
if self.coordinator.is_master(): if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The optimizer is going to be split to checkpoint shards. "
f"index located at {save_index_file}.") f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str): def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
""" """
...@@ -185,8 +193,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO): ...@@ -185,8 +193,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Load param_groups. # Load param_groups.
param_group_path = ckpt_index_file.get_param_group_filename() param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None: if param_group_path is None:
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ raise RuntimeError(
Lacking param group file under current directory.') f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_param_groups = torch.load(param_group_path) saved_param_groups = torch.load(param_group_path)
optimizer.load_param_groups(saved_param_groups) optimizer.load_param_groups(saved_param_groups)
...@@ -274,11 +284,11 @@ class GeminiPlugin(DPPluginBase): ...@@ -274,11 +284,11 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict: Optional[dict] = None, chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None, chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static", placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16", precision: str = "fp16",
pin_memory: bool = False, pin_memory: bool = False,
force_outputs_fp32: bool = False, force_outputs_fp32: bool = False,
...@@ -300,7 +310,7 @@ class GeminiPlugin(DPPluginBase): ...@@ -300,7 +310,7 @@ class GeminiPlugin(DPPluginBase):
verbose: bool = False, verbose: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported' assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
self.gemini_config = dict( self.gemini_config = dict(
chunk_config_dict=chunk_config_dict, chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()), chunk_init_device=(chunk_init_device or get_current_device()),
...@@ -319,16 +329,20 @@ class GeminiPlugin(DPPluginBase): ...@@ -319,16 +329,20 @@ class GeminiPlugin(DPPluginBase):
memstats=memstats, memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision], mixed_precision=PRECISION_STR_TO_DTYPE[precision],
) )
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,) self.zero_optim_config = dict(
self.optim_kwargs = dict(initial_scale=initial_scale, gpu_margin_mem_ratio=gpu_margin_mem_ratio,
growth_factor=growth_factor, )
backoff_factor=backoff_factor, self.optim_kwargs = dict(
growth_interval=growth_interval, initial_scale=initial_scale,
hysteresis=hysteresis, growth_factor=growth_factor,
min_scale=min_scale, backoff_factor=backoff_factor,
max_scale=max_scale, growth_interval=growth_interval,
max_norm=max_norm, hysteresis=hysteresis,
norm_type=norm_type) min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type,
)
self.verbose = verbose self.verbose = verbose
def support_no_sync(self) -> bool: def support_no_sync(self) -> bool:
...@@ -344,7 +358,7 @@ class GeminiPlugin(DPPluginBase): ...@@ -344,7 +358,7 @@ class GeminiPlugin(DPPluginBase):
return True return True
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
return ['cuda'] return ["cuda"]
def configure( def configure(
self, self,
...@@ -354,7 +368,6 @@ class GeminiPlugin(DPPluginBase): ...@@ -354,7 +368,6 @@ class GeminiPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None, dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None, lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper): if not isinstance(model, ModelWrapper):
# convert model to sync bn # convert model to sync bn
# FIXME(ver217): gemini does not support sync bn # FIXME(ver217): gemini does not support sync bn
...@@ -368,13 +381,10 @@ class GeminiPlugin(DPPluginBase): ...@@ -368,13 +381,10 @@ class GeminiPlugin(DPPluginBase):
# wrap the model with Gemini # wrap the model with Gemini
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose) model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
if optimizer is not None and \ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer(
optimizer = GeminiOptimizer(optimizer, optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
model.unwrap(), )
**self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler return model, optimizer, criterion, dataloader, lr_scheduler
......
...@@ -37,10 +37,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): ...@@ -37,10 +37,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class HybridParallelModule(ModelWrapper): class HybridParallelModule(ModelWrapper):
def __init__(
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, self,
ddp_config: dict, custom_policy: Policy) -> None: module: Module,
precision: str,
shard_config: ShardConfig,
dp_group: ProcessGroup,
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group self.dp_group = dp_group
...@@ -54,13 +60,14 @@ class HybridParallelModule(ModelWrapper): ...@@ -54,13 +60,14 @@ class HybridParallelModule(ModelWrapper):
for shared_param in self.shared_params: for shared_param in self.shared_params:
if len(shared_param) > 0: if len(shared_param) > 0:
self.shared_param_process_groups.append( self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
)
# setting mixed_precision # setting mixed_precision
self.mixed_precision = None self.mixed_precision = None
if precision == 'fp16': if precision == "fp16":
self.mixed_precision = torch.float16 self.mixed_precision = torch.float16
elif precision == 'bf16': elif precision == "bf16":
self.mixed_precision = torch.bfloat16 self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None: if self.mixed_precision is not None:
module = module.to(self.mixed_precision) module = module.to(self.mixed_precision)
...@@ -123,22 +130,21 @@ def get_param_info(optim: Optimizer): ...@@ -123,22 +130,21 @@ def get_param_info(optim: Optimizer):
if optim is None: if optim is None:
return {} return {}
param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0 start_index = 0
for group in optim.param_groups: for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
packed_group["params"] = []
packed_group = {k: v for k, v in group.items() if k != 'params'} for param_id, param in enumerate(group["params"], start_index):
packed_group['params'] = []
for param_id, param in enumerate(group['params'], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None original_shape = param.shape if isinstance(param, torch.Tensor) else None
packed_group['params'].append(param_id) packed_group["params"].append(param_id)
param_info['param2id'][id(param)] = param_id param_info["param2id"][id(param)] = param_id
param_info['id2param'][param_id] = id(param) param_info["id2param"][param_id] = id(param)
param_info['param2shape'][id(param)] = original_shape param_info["param2shape"][id(param)] = original_shape
param_info['param_groups'].append(packed_group) param_info["param_groups"].append(packed_group)
start_index += len(group['params']) start_index += len(group["params"])
return param_info return param_info
...@@ -147,13 +153,12 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): ...@@ -147,13 +153,12 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
model_params = set(model.parameters()) model_params = set(model.parameters())
new_param_groups = [] new_param_groups = []
for group in optim.param_groups: for group in optim.param_groups:
params = [p for p in group['params'] if p in model_params] params = [p for p in group["params"] if p in model_params]
new_param_groups.append({**group, 'params': params}) new_param_groups.append({**group, "params": params})
optim.__setstate__({'param_groups': new_param_groups}) optim.__setstate__({"param_groups": new_param_groups})
class HybridParallelNaiveOptimizer(OptimizerWrapper): class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
self.param_info = param_info self.param_info = param_info
if use_pipeline: if use_pipeline:
...@@ -162,60 +167,87 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper): ...@@ -162,60 +167,87 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
def __init__(self, self,
optim: Optimizer, optim: Optimizer,
model: Module, model: Module,
use_pipeline: bool, use_pipeline: bool,
param_info: OrderedDict, param_info: OrderedDict,
precision: str = 'fp16', precision: str = "fp16",
initial_scale: float = 2**16, initial_scale: float = 2**16,
min_scale: float = 1, min_scale: float = 1,
growth_factor: float = 2, growth_factor: float = 2,
backoff_factor: float = 0.5, backoff_factor: float = 0.5,
growth_interval: int = 1000, growth_interval: int = 1000,
hysteresis: int = 2, hysteresis: int = 2,
max_scale: float = 2**32, max_scale: float = 2**32,
max_norm: float = 0): max_norm: float = 0,
):
self.param_info = param_info self.param_info = param_info
if use_pipeline: if use_pipeline:
init_pipeline_optimizer(optim, model) init_pipeline_optimizer(optim, model)
super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, super().__init__(
hysteresis, max_scale, max_norm) optim,
precision,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
max_norm,
)
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__( def __init__(
self, self,
optimizer: Optimizer, optimizer: Optimizer,
model: Module, model: Module,
use_pipeline: bool, use_pipeline: bool,
param_info: OrderedDict, param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1, min_scale: int = 1,
growth_factor: float = 2., growth_factor: float = 2.0,
backoff_factor: float = .5, backoff_factor: float = 0.5,
growth_interval: int = 2000, growth_interval: int = 2000,
hysteresis: int = 2, hysteresis: int = 2,
max_scale: int = 2**24, max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False, verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None, communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True, overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None): forced_dtype: Optional[torch.dtype] = None,
):
self.param_info = param_info self.param_info = param_info
if use_pipeline: if use_pipeline:
init_pipeline_optimizer(optimizer, model) init_pipeline_optimizer(optimizer, model)
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, super().__init__(
hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, optimizer,
overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, initial_scale,
forced_dtype) min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
clip_grad_norm,
verbose,
reduce_bucket_size,
communication_dtype,
overlap_communication,
partition_grad,
cpu_offload,
dp_process_group,
tp_process_group,
forced_dtype,
)
class HybridParallelPlugin(PipelinePluginBase): class HybridParallelPlugin(PipelinePluginBase):
...@@ -276,46 +308,47 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -276,46 +308,47 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
""" """
def __init__(self, def __init__(
tp_size: int, self,
pp_size: int, tp_size: int,
precision: str = 'fp16', pp_size: int,
zero_stage: int = 0, precision: str = "fp16",
enable_all_optimization: bool = False, zero_stage: int = 0,
enable_fused_normalization: bool = False, enable_all_optimization: bool = False,
enable_flash_attention: bool = False, enable_fused_normalization: bool = False,
enable_jit_fused: bool = False, enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False, enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False, enable_sequence_parallelism: bool = False,
num_microbatches: Optional[int] = None, enable_sequence_overlap: bool = False,
microbatch_size: Optional[int] = None, num_microbatches: Optional[int] = None,
initial_scale: float = 2**16, microbatch_size: Optional[int] = None,
min_scale: float = 1, initial_scale: float = 2**16,
growth_factor: float = 2, min_scale: float = 1,
backoff_factor: float = 0.5, growth_factor: float = 2,
growth_interval: int = 1000, backoff_factor: float = 0.5,
hysteresis: int = 2, growth_interval: int = 1000,
max_scale: float = 2**32, hysteresis: int = 2,
max_norm: float = 0, max_scale: float = 2**32,
broadcast_buffers: bool = True, max_norm: float = 0,
ddp_bucket_cap_mb: int = 25, broadcast_buffers: bool = True,
find_unused_parameters: bool = False, ddp_bucket_cap_mb: int = 25,
check_reduction: bool = False, find_unused_parameters: bool = False,
gradient_as_bucket_view: bool = False, check_reduction: bool = False,
static_graph: bool = False, gradient_as_bucket_view: bool = False,
zero_bucket_size_in_m: int = 12, static_graph: bool = False,
cpu_offload: bool = False, zero_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None, cpu_offload: bool = False,
overlap_communication: bool = True, communication_dtype: Optional[torch.dtype] = None,
custom_policy: Policy = None) -> None: overlap_communication: bool = True,
custom_policy: Policy = None,
) -> None:
super().__init__() super().__init__()
assert dist.get_world_size() % ( assert (
tp_size * pp_size dist.get_world_size() % (tp_size * pp_size) == 0
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism: if enable_sequence_parallelism:
assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
self.tp_size = tp_size self.tp_size = tp_size
self.pp_size = pp_size self.pp_size = pp_size
...@@ -334,24 +367,28 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -334,24 +367,28 @@ class HybridParallelPlugin(PipelinePluginBase):
self.custom_policy = custom_policy self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2) assert zero_stage in (0, 1, 2)
if self.pp_size > 1: if self.pp_size > 1:
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' assert (
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, self.schedule = OneForwardOneBackwardSchedule(
num_microbatches=num_microbatches, self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
microbatch_size=microbatch_size) )
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, self.shard_config = ShardConfig(
pipeline_stage_manager=self.stage_manager, tensor_parallel_process_group=self.tp_group,
enable_tensor_parallelism=self.tp_size > 1, pipeline_stage_manager=self.stage_manager,
enable_all_optimization=self.enable_all_optimization, enable_tensor_parallelism=self.tp_size > 1,
enable_fused_normalization=self.enable_fused_normalization, enable_all_optimization=self.enable_all_optimization,
enable_flash_attention=self.enable_flash_attention, enable_fused_normalization=self.enable_fused_normalization,
enable_jit_fused=self.enable_jit_fused, enable_flash_attention=self.enable_flash_attention,
enable_sequence_parallelism=enable_sequence_parallelism, enable_jit_fused=self.enable_jit_fused,
enable_sequence_overlap=enable_sequence_overlap) enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
)
self.amp_config = dict( self.amp_config = dict(
initial_scale=initial_scale, initial_scale=initial_scale,
growth_factor=growth_factor, growth_factor=growth_factor,
...@@ -362,18 +399,22 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -362,18 +399,22 @@ class HybridParallelPlugin(PipelinePluginBase):
max_scale=max_scale, max_scale=max_scale,
) )
self.ddp_config = dict(broadcast_buffers=broadcast_buffers, self.ddp_config = dict(
bucket_cap_mb=ddp_bucket_cap_mb, broadcast_buffers=broadcast_buffers,
find_unused_parameters=find_unused_parameters, bucket_cap_mb=ddp_bucket_cap_mb,
check_reduction=check_reduction, find_unused_parameters=find_unused_parameters,
gradient_as_bucket_view=gradient_as_bucket_view, check_reduction=check_reduction,
static_graph=static_graph) gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, self.zero_config = dict(
communication_dtype=communication_dtype, reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
overlap_communication=overlap_communication, communication_dtype=communication_dtype,
cpu_offload=cpu_offload, overlap_communication=overlap_communication,
partition_grad=(self.zero_stage == 2)) cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
)
self.max_norm = max_norm self.max_norm = max_norm
...@@ -382,10 +423,10 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -382,10 +423,10 @@ class HybridParallelPlugin(PipelinePluginBase):
return self.pp_size > 1 return self.pp_size > 1
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
return ['cuda'] return ["cuda"]
def supported_precisions(self) -> List[str]: def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16', 'fp32'] return ["fp16", "bf16", "fp32"]
def control_device(self) -> bool: def control_device(self) -> bool:
return True return True
...@@ -410,57 +451,67 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -410,57 +451,67 @@ class HybridParallelPlugin(PipelinePluginBase):
param_info = get_param_info(optimizer) param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper): if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, model = HybridParallelModule(
self.ddp_config, self.custom_policy) model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0: if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']: if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(optimizer, optimizer = HybridParallelAMPOptimizer(
model, optimizer,
use_pipeline=self.enable_pipeline_parallelism, model,
param_info=param_info, use_pipeline=self.enable_pipeline_parallelism,
precision=self.precision, param_info=param_info,
max_norm=self.max_norm, precision=self.precision,
**self.amp_config) max_norm=self.max_norm,
self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, **self.amp_config,
optimizer.master_to_working_map) )
self.checkpoint_io.link_master_and_working_param(
optimizer.working_to_master_map, optimizer.master_to_working_map
)
else: else:
optimizer = HybridParallelNaiveOptimizer(optimizer, optimizer = HybridParallelNaiveOptimizer(
model, optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
use_pipeline=self.enable_pipeline_parallelism, )
param_info=param_info)
else: else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(optimizer, optimizer = HybridParallelZeroOptimizer(
model, optimizer,
use_pipeline=self.enable_pipeline_parallelism, model,
param_info=param_info, use_pipeline=self.enable_pipeline_parallelism,
dp_process_group=self.dp_group, param_info=param_info,
tp_process_group=self.tp_group, dp_process_group=self.dp_group,
verbose=True, tp_process_group=self.tp_group,
clip_grad_norm=self.max_norm, verbose=True,
**self.zero_config, clip_grad_norm=self.max_norm,
**self.amp_config) **self.zero_config,
self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, **self.amp_config,
optimizer._param_store.master_to_working_param) )
self.checkpoint_io.link_master_and_working_param(
optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
)
return model, optimizer, criterion, dataloader, lr_scheduler return model, optimizer, criterion, dataloader, lr_scheduler
def execute_pipeline(self, def execute_pipeline(
data_iter: Iterator, self,
model: HybridParallelModule, data_iter: Iterator,
criterion: Callable[[Any, Any], torch.Tensor], model: HybridParallelModule,
optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, criterion: Callable[[Any, Any], torch.Tensor],
HybridParallelZeroOptimizer]] = None, optimizer: Optional[
return_loss: bool = True, Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer]
return_outputs: bool = False) -> dict: ] = None,
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' return_loss: bool = True,
return_outputs: bool = False,
) -> dict:
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
# return loss or outputs if needed # return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx: with ctx:
outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss, outputs = self.schedule.forward_backward_step(
return_outputs) model, data_iter, criterion, optimizer, return_loss, return_outputs
)
model.sync_shared_params() model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer): if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad() optimizer.sync_grad()
...@@ -468,15 +519,9 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -468,15 +519,9 @@ class HybridParallelPlugin(PipelinePluginBase):
model.sync_grads() model.sync_grads()
return outputs return outputs
def prepare_dataloader(self, def prepare_dataloader(
dataset, self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
batch_size, ):
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r""" r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
...@@ -499,10 +544,9 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -499,10 +544,9 @@ class HybridParallelPlugin(PipelinePluginBase):
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
""" """
_kwargs = kwargs.copy() _kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, sampler = DistributedSampler(
num_replicas=self.pg_mesh.size(DP_AXIS), dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
rank=self.pg_mesh.coordinate(DP_AXIS), )
shuffle=shuffle)
# Deterministic dataloader # Deterministic dataloader
def seed_worker(worker_id): def seed_worker(worker_id):
...@@ -511,14 +555,16 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -511,14 +555,16 @@ class HybridParallelPlugin(PipelinePluginBase):
torch.manual_seed(worker_seed) torch.manual_seed(worker_seed)
random.seed(worker_seed) random.seed(worker_seed)
return DataLoader(dataset, return DataLoader(
batch_size=batch_size, dataset,
sampler=sampler, batch_size=batch_size,
worker_init_fn=seed_worker, sampler=sampler,
drop_last=drop_last, worker_init_fn=seed_worker,
pin_memory=pin_memory, drop_last=drop_last,
num_workers=num_workers, pin_memory=pin_memory,
**_kwargs) num_workers=num_workers,
**_kwargs,
)
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
......
import logging import logging
import os import os
import warnings
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from types import MethodType from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple, Union from typing import Callable, Iterator, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -33,7 +31,7 @@ from colossalai.zero import LowLevelZeroOptimizer ...@@ -33,7 +31,7 @@ from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO from .torch_ddp_plugin import TorchDDPCheckpointIO
__all__ = ['LowLevelZeroPlugin'] __all__ = ["LowLevelZeroPlugin"]
def _convert_floating_point(x, dtype: torch.dtype = torch.float16): def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
...@@ -42,17 +40,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): ...@@ -42,17 +40,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
return x return x
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32'] SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class LowLevelZeroModel(ModelWrapper, AMPModelMixin): class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None: def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module) super().__init__(module)
self.dtype = None self.dtype = None
if precision == 'fp16': if precision == "fp16":
self.dtype = torch.float16 self.dtype = torch.float16
elif precision == 'bf16': elif precision == "bf16":
self.dtype = torch.bfloat16 self.dtype = torch.bfloat16
if self.dtype is not None: if self.dtype is not None:
module = module.to(self.dtype) module = module.to(self.dtype)
...@@ -74,7 +71,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): ...@@ -74,7 +71,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
"""Save optimizer to checkpoint but only on master process. """Save optimizer to checkpoint but only on master process.
...@@ -91,12 +87,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): ...@@ -91,12 +87,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
if self.coordinator.is_master(): if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False) save_state_dict(state_dict, checkpoint, use_safetensors=False)
def save_sharded_optimizer(self, def save_sharded_optimizer(
optimizer: OptimizerWrapper, self,
checkpoint: str, optimizer: OptimizerWrapper,
gather_dtensor: bool = False, checkpoint: str,
prefix: str = None, gather_dtensor: bool = False,
size_per_shard: int = 1024): prefix: str = None,
size_per_shard: int = 1024,
):
""" """
Save sharded Zero-optimizer checkpoint under the given checkpointing path. Save sharded Zero-optimizer checkpoint under the given checkpointing path.
The following files will be created under the path: The following files will be created under the path:
...@@ -148,9 +146,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): ...@@ -148,9 +146,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master(): if self.coordinator.is_master():
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. " logging.info(
f"You can find where each parameters has been saved in the " f"The optimizer is going to be split to checkpoint shards. "
f"index located at {save_index_file}.") f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
"""Load sharded optimizer with the given path to index file. """Load sharded optimizer with the given path to index file.
...@@ -170,8 +170,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): ...@@ -170,8 +170,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
# Load param_groups # Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename() param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None: if param_group_path is None:
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \ raise RuntimeError(
Lacking param group file under current directory.') f"Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory."
)
id_map = load_param_groups_into_optimizer(optimizer, param_group_path) id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
...@@ -181,9 +183,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): ...@@ -181,9 +183,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
# shard state dict # shard state dict
for param_idx, state in state_dict.items(): for param_idx, state in state_dict.items():
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step': if isinstance(v, torch.Tensor) and k != "step":
padding_size = (self.coordinator.world_size - padding_size = (
v.numel() % self.coordinator.world_size) % self.coordinator.world_size self.coordinator.world_size - v.numel() % self.coordinator.world_size
) % self.coordinator.world_size
with torch.no_grad(): with torch.no_grad():
v = v.flatten() v = v.flatten()
if padding_size > 0: if padding_size > 0:
...@@ -194,33 +197,39 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): ...@@ -194,33 +197,39 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
sharded_optimizer_loading_epilogue(optimizer) sharded_optimizer_loading_epilogue(optimizer)
def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, def save_unsharded_model(
use_safetensors: bool): self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
):
assert isinstance(model, LowLevelZeroModel) assert isinstance(model, LowLevelZeroModel)
super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors) super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
def save_sharded_model(self, def save_sharded_model(
model: nn.Module, self,
checkpoint_path: str, model: nn.Module,
gather_dtensor: bool = True, checkpoint_path: str,
prefix: Optional[str] = None, gather_dtensor: bool = True,
max_shard_size: int = 1024, prefix: Optional[str] = None,
use_safetensors: bool = False): max_shard_size: int = 1024,
use_safetensors: bool = False,
):
assert isinstance(model, LowLevelZeroModel) assert isinstance(model, LowLevelZeroModel)
super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, super().save_sharded_model(
use_safetensors) model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)
def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True): def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel) assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict) super().load_unsharded_model(model.module, checkpoint, strict)
model.update_master_params() model.update_master_params()
def load_sharded_model(self, def load_sharded_model(
model: LowLevelZeroModel, self,
checkpoint_index_file: Path, model: LowLevelZeroModel,
strict: bool = False, checkpoint_index_file: Path,
use_safetensors: bool = False, strict: bool = False,
load_sub_module: bool = True): use_safetensors: bool = False,
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel) assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module) super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params() model.update_master_params()
...@@ -264,7 +273,7 @@ class LowLevelZeroPlugin(DPPluginBase): ...@@ -264,7 +273,7 @@ class LowLevelZeroPlugin(DPPluginBase):
def __init__( def __init__(
self, self,
stage: int = 1, stage: int = 1,
precision: str = 'fp16', precision: str = "fp16",
initial_scale: float = 2**32, initial_scale: float = 2**32,
min_scale: float = 1, min_scale: float = 1,
growth_factor: float = 2, growth_factor: float = 2,
...@@ -281,9 +290,9 @@ class LowLevelZeroPlugin(DPPluginBase): ...@@ -281,9 +290,9 @@ class LowLevelZeroPlugin(DPPluginBase):
verbose: bool = False, verbose: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now' assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
self.stage = stage self.stage = stage
self.precision = precision self.precision = precision
self.zero_optim_kwargs = dict( self.zero_optim_kwargs = dict(
...@@ -319,7 +328,7 @@ class LowLevelZeroPlugin(DPPluginBase): ...@@ -319,7 +328,7 @@ class LowLevelZeroPlugin(DPPluginBase):
return True return True
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
return ['cuda'] return ["cuda"]
def configure( def configure(
self, self,
...@@ -329,15 +338,13 @@ class LowLevelZeroPlugin(DPPluginBase): ...@@ -329,15 +338,13 @@ class LowLevelZeroPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None, dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None, lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper): if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision) model = LowLevelZeroModel(model, self.precision)
if optimizer is not None and \ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
not isinstance(optimizer, OptimizerWrapper): optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer, optimizer, **self.zero_optim_kwargs, verbose=self.verbose
**self.zero_optim_kwargs, )
verbose=self.verbose)
# inject update_master_params # inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model) model.update_master_params = MethodType(optimizer.update_master_params, model)
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional, Tuple, Union from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
...@@ -9,11 +9,10 @@ from torch.utils.data import DataLoader, Dataset ...@@ -9,11 +9,10 @@ from torch.utils.data import DataLoader, Dataset
from colossalai.checkpoint_io import CheckpointIO from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
__all__ = ['Plugin'] __all__ = ["Plugin"]
class Plugin(ABC): class Plugin(ABC):
@abstractmethod @abstractmethod
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
pass pass
...@@ -51,33 +50,31 @@ class Plugin(ABC): ...@@ -51,33 +50,31 @@ class Plugin(ABC):
""" """
Whether the plugin controls the checkpoint io Whether the plugin controls the checkpoint io
""" """
pass
@abstractmethod @abstractmethod
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
""" """
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True. Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
""" """
pass
@abstractmethod @abstractmethod
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
""" """
Context manager to disable gradient synchronization. Context manager to disable gradient synchronization.
""" """
pass
@abstractmethod @abstractmethod
def prepare_dataloader(self, def prepare_dataloader(
dataset: Dataset, self,
batch_size: int, dataset: Dataset,
shuffle: bool = False, batch_size: int,
seed: int = 1024, shuffle: bool = False,
drop_last: bool = False, seed: int = 1024,
pin_memory: bool = False, drop_last: bool = False,
num_workers: int = 0, pin_memory: bool = False,
**kwargs): num_workers: int = 0,
**kwargs,
):
"""Prepare a dataloader for distributed training. The dataloader will be wrapped by """Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` `torch.utils.data.DataLoader`
""" """
pass
...@@ -9,13 +9,14 @@ from .plugin_base import Plugin ...@@ -9,13 +9,14 @@ from .plugin_base import Plugin
class PipelinePluginBase(Plugin): class PipelinePluginBase(Plugin):
@abstractmethod @abstractmethod
def execute_pipeline(self, def execute_pipeline(
data_iter: Iterator, self,
model: ModelWrapper, data_iter: Iterator,
criterion: Callable[[Any, Any], torch.Tensor], model: ModelWrapper,
optimizer: Optional[OptimizerWrapper] = None, criterion: Callable[[Any, Any], torch.Tensor],
return_loss: bool = True, optimizer: Optional[OptimizerWrapper] = None,
return_outputs: bool = False) -> dict: return_loss: bool = True,
return_outputs: bool = False,
) -> dict:
pass pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment