Commit 9e768b59 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 7bc5a8e3 8aed02b9
......@@ -8,8 +8,11 @@ import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
"transpose_partition_dim",
"update_partition_dim",
"enumerate_all_possible_1d_sharding",
"enumerate_all_possible_2d_sharding",
"generate_sharding_size",
]
......@@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
dim1 (int): the tensor dimension to switch
dim2 (int): the tensor dimension to switch
"""
assert len(sharding_spec.entire_shape) >= 2, \
'The entire_shape of the sharding spec must have at least 2 dimensions'
assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions"
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
......@@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
return sharding_spec
def update_partition_dim(sharding_spec: ShardingSpec,
dim_mapping: Dict[int, int],
physical_shape: torch.Size,
inplace: bool = False):
def update_partition_dim(
sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False
):
"""
This method is used to update the partition dim dict from the logical one to the physical one.
......@@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec,
new_dim_partition_dict[tensor_dim] = mesh_dims
# update sharding spec
current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=new_dim_partition_dict)
current_sharding_spec.__init__(
device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict
)
return current_sharding_spec
......
......@@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
......@@ -40,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
return new_shape
def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_output_dim: int, chunk_size=2) -> str:
"""
Generate chunk loop start
......@@ -52,7 +63,7 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup
Args:
chunk_input (List[Node]): chunk input node
chunk_output (Node): chunk output node
chunk_ouput_dim (int): chunk output node chunk dim
chunk_output_dim (int): chunk output node chunk dim
chunk_size (int): chunk size. Defaults to 2.
Returns:
......@@ -64,23 +75,36 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
input_node.name)
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
shape_str,
input_node.name,
input_node.name,
)
tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
input_node.name, input_node.name)
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
chunk_output[i].name,
shape_str,
input_node.name,
input_node.name,
)
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_ouput_dim[0]]
chunk_shape = out_shape[chunk_output_dim[0]]
context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
return context
def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
def _gen_loop_end(
chunk_inputs: List[Node],
chunk_non_compute_inputs: List[Node],
node_list: List[Node],
chunk_outputs_idx: int,
chunk_outputs_non_tensor: List[Node],
search_chunk: SearchChunk,
) -> str:
"""
Generate chunk loop end
......@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
if (
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body
......@@ -203,11 +229,12 @@ def _add_node_slice(
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
get_node_shape(chunk_node))
chunk_slice = _gen_chunk_slice_dim(
chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
)
if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])):
for i in range(len(chunk_node.meta["tensor_meta"])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
......@@ -216,13 +243,15 @@ def _add_node_slice(
return body
def emit_code_with_chunk(body: List[str],
def emit_code_with_chunk(
body: List[str],
nodes: Iterable[Node],
emit_node_func: Callable,
delete_unused_value_func: Callable,
search_chunk: SearchChunk,
chunk_infos: List,
eval_mem: bool = False):
eval_mem: bool = False,
):
"""
Emit code with chunk according to chunk_infos.
......@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"],
))
)
)
if within_chunk_region:
emit_node_func(node, body)
......@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
% (node.name)
)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
......@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
% (node.name))
% (node.name)
)
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
_gen_loop_end(
chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
node_list,
chunk_ends[region_idx],
chunk_outputs_non_tensor[region_idx],
search_chunk,
)
)
within_chunk_region = False
node_idx += 1
......@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
def __init__(self,
def __init__(
self,
meta_graph,
max_memory: int = None,
print_mem: bool = False,
print_progress: bool = False,
eval_mem: bool = False) -> None:
eval_mem: bool = False,
) -> None:
super().__init__()
self.eval_mem = eval_mem
# find the chunk regions
......@@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE:
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
......@@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE:
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
......@@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE:
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
......@@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE:
assert isinstance(node.target, str)
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
f"({_format_args(node.args[1:], node.kwargs)})")
f"({_format_args(node.args[1:], node.kwargs)})"
)
return
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
)
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
body.append(
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
)
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
and node.args[1].isidentifier() and len(node.args) == 2):
if (
global_name == "getattr"
and isinstance(node.args, tuple)
and isinstance(node.args[1], str)
and node.args[1].isidentifier()
and len(node.args) == 2
):
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
)
return
body.append(
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
)
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == "call_module":
assert isinstance(node.target, str)
body.append(f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
body.append(
f"{repr(node)}{maybe_type_annotation} = "
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
)
return
elif node.op == "get_attr":
assert isinstance(node.target, str)
......@@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
self.eval_mem)
emit_code_with_chunk(
body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
......
import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import NodeMgr, get_node_shape, is_non_memory_node
......@@ -62,12 +59,9 @@ class EstimateMemory(object):
delete_node_dict[node] = max(node_user_idx)
return delete_node_dict
def _remove_deactive_node(self,
user_idx: int,
user: Node,
active_nodes: List,
delete_node_dict: List,
kept_nodes: List = None) -> None:
def _remove_deactive_node(
self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None
) -> None:
"""
remove deactivate nodes from active nodes
"""
......@@ -184,7 +178,6 @@ class EstimateMemory(object):
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_mgr.get_node_list()):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
......@@ -193,8 +186,9 @@ class EstimateMemory(object):
# determine chunk ratio for current node
if chunk_within:
chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
chunk_sizes[chunk_region_idx])
chunk_ratio = self._get_chunk_ratio(
node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
)
# add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio)
......
......@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
......@@ -121,8 +121,10 @@ class SearchChunk(object):
# check if peak node already in chunk info
if chunk_regions is not None:
for i in chunk_regions:
if i["region"][0] < peak_region[0] <= i["region"][1] or \
i["region"][0] < peak_region[1] <= i["region"][1]:
if (
i["region"][0] < peak_region[0] <= i["region"][1]
or i["region"][0] < peak_region[1] <= i["region"][1]
):
return None
active_node_num = [len(i) for i in active_node]
......@@ -146,9 +148,9 @@ class SearchChunk(object):
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
chunk_region_start = region[1] + 1
elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
......@@ -180,8 +182,9 @@ class SearchChunk(object):
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
end_idx):
if not self.trace_flow.check_region_start_end(
start_node, start_dim, start_idx, end_node, end_dim, end_idx
):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
......@@ -215,7 +218,8 @@ class SearchChunk(object):
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
self.node_mgr.get_node_by_idx(end_idx)):
self.node_mgr.get_node_by_idx(end_idx)
):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
......@@ -279,15 +283,18 @@ class SearchChunk(object):
chunk_infos.append(chunk_info)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos)
self.node_mgr.get_node_list(), chunk_infos
)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
get_logger().info(
"AutoChunk find chunk region %d = (%d, %d)"
% (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])
)
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
chunk_infos,
print_mem=True)
self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos, print_mem=True
)
return chunk_infos
......@@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object):
def __init__(
self,
trace_indice: TraceIndice,
......@@ -57,16 +56,18 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append({
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
}
)
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
......@@ -90,13 +91,15 @@ class SelectChunk(object):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
chunk_infos)
chunk_info["chunk_size"] = self._chunk_size_binary_search(
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
)
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
......@@ -109,9 +112,10 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
......@@ -139,8 +143,10 @@ class SelectChunk(object):
return None
# get max possible chunk region
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
max_possible_chunk_region = (
min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]),
)
# get mem for chunk region
regions_dict_list = []
......@@ -149,15 +155,17 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
regions_dict_list.append({
regions_dict_list.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
}
)
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
......@@ -175,7 +183,9 @@ class SelectChunk(object):
return False
for i in chunk_infos:
region = i["region"]
if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
(chunk_region_start < region[0] and chunk_region_end < region[0])):
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
return False
return True
......@@ -16,7 +16,6 @@ from .utils import (
class TraceFlow(object):
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
......@@ -64,7 +63,7 @@ class TraceFlow(object):
return False
return True
def _assgin_single_node_flow(
def _assign_single_node_flow(
self,
arg_node: Node,
start_idx: int,
......@@ -177,7 +176,7 @@ class TraceFlow(object):
if get_node_shape(arg) is None:
continue
arg_list.append(arg)
flow_flag = self._assgin_single_node_flow(
flow_flag = self._assign_single_node_flow(
arg,
start_idx,
end_idx,
......@@ -315,7 +314,7 @@ class TraceFlow(object):
chunk_info["args"]["prepose_nodes"] = prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleteing them in the loop
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
......@@ -328,7 +327,8 @@ class TraceFlow(object):
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
)
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
......@@ -366,13 +366,14 @@ class TraceFlow(object):
# find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
# reassgin reshape size, some size may have changed due to chunk
chunk_info = self._reassgin_reshape_size(chunk_info)
# reassign reshape size, some size may have changed due to chunk
chunk_info = self._reassign_reshape_size(chunk_info)
return chunk_info
def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
chunk_info: Dict):
def _get_other_output_info(
self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
......@@ -384,8 +385,8 @@ class TraceFlow(object):
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
......@@ -421,17 +422,18 @@ class TraceFlow(object):
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
)
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
chunk_info["outputs_dim"].append(output_dim)
return True
def _reassgin_reshape_size(self, chunk_info):
def _reassign_reshape_size(self, chunk_info):
"""
Some shape args in reshape may have changed due to chunk
reassgin those changed shape
reassign those changed shape
"""
chunk_region = chunk_info["region"]
reshape_size = {}
......@@ -443,8 +445,11 @@ class TraceFlow(object):
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
reshape_args[0].meta['fwd_out']) > 1:
if (
len(reshape_args) == 1
and get_node_shape(reshape_args[0]) is None
and len(reshape_args[0].meta["fwd_out"]) > 1
):
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
......@@ -462,16 +467,17 @@ class TraceFlow(object):
chunk_info["reshape_size"] = reshape_size
return chunk_info
def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
end_idx: int) -> bool:
def check_region_start_end(
self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
) -> bool:
"""
check if region start and end is legal
"""
# dim cannot be None
if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
return False
# dim size cannot be 1
if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
return False
# must have users
if len(end_node.users) == 0:
......
import copy
from typing import Dict, List, Tuple
from typing import Dict, List
from torch.fx.node import Node
......@@ -18,7 +18,7 @@ class TraceIndice(object):
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
This class will record every node's dims' indice, compute and source.
Attibutes:
Attributes:
node_list (List)
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
indice_view_list (Dict): not used for now
......@@ -397,7 +397,7 @@ class TraceIndice(object):
input_node = node.args[0]
assert len(get_node_shape(input_node)) == 4
# assgin index
# assign index
self._assign_indice_as_input(node, node_idx, input_node)
self._del_dim(node_idx, 1)
self._add_dim(node_idx, 1)
......@@ -412,10 +412,10 @@ class TraceIndice(object):
node_idx (int)
"""
# get conv input
assert node.kwargs['size'] is None
assert node.kwargs["size"] is None
assert len(get_node_shape(node)) == 4
# assgin index
# assign index
self._assign_indice_as_input(node, node_idx)
self._mark_computation(node, node_idx, [-1, -2])
......@@ -461,7 +461,7 @@ class TraceIndice(object):
nodes_in.append(node_in)
self._inherit_more_indice_from_node_with_exclude(node_in, node)
def _assgin_no_change_indice(self, node, idx):
def _assign_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
......@@ -792,7 +792,7 @@ class TraceIndice(object):
self._add_dim(node_idx, i)
dim_from.reverse()
# inheirt indice from current node
# inherit indice from current node
if len(dim_from) != 0 and len(dim_to) != 0:
if dim_diff == 1:
if origin_shape[dim_from[0]] == 1:
......@@ -826,7 +826,7 @@ class TraceIndice(object):
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
......@@ -852,7 +852,7 @@ class TraceIndice(object):
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
self._assgin_no_change_indice(node, idx)
self._assign_no_change_indice(node, idx)
elif "new_ones" == node_name:
self._assign_all_indice(node, idx)
elif "flatten" == node_name:
......@@ -876,10 +876,24 @@ class TraceIndice(object):
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
elif any(n == node_name for n in [
"mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
"sin", "cos"
]):
elif any(
n == node_name
for n in [
"mul",
"add",
"sigmoid",
"relu",
"sub",
"truediv",
"pow",
"dropout",
"where",
"tanh",
"exp",
"sin",
"cos",
]
):
self._assign_elementwise_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
......@@ -914,7 +928,7 @@ class TraceIndice(object):
elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx)
elif "identity" == node_name:
self._assgin_no_change_indice(node, idx)
self._assign_no_change_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
self._assign_elementwise_indice(node, idx)
else:
......
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from typing import Any, Dict, List, Union
from torch.fx.node import Node
......@@ -10,7 +10,6 @@ logger = get_dist_logger()
class NodeMgr(object):
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
......@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
if (input_node not in nodes and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)):
if (
input_node not in nodes
and input_node not in input_nodes
and not is_non_compute_node_except_placeholder(input_node)
):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
if (output_node not in nodes and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)):
if (
output_node not in nodes
and node not in output_nodes
and not is_non_compute_node_except_placeholder_output(output_node)
):
output_nodes.append(node)
return input_nodes, output_nodes
......@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
node.meta['fwd_out'][0], int):
elif (
len(node.meta["fwd_out"]) > 0
and isinstance(node.meta["fwd_out"], list)
and isinstance(node.meta["fwd_out"][0], int)
):
out.append(node)
return out
import torch
import torch.nn as nn
__all__ = ['Accelerator']
__all__ = ["Accelerator"]
_supported_devices = [
'cpu',
'cuda',
"cpu",
"cuda",
# To be supported
# 'xpu',
# 'npu',
......@@ -25,21 +24,22 @@ class Accelerator:
def __init__(self, device: str):
self.device = device
assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
assert (
self.device in _supported_devices
), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
def bind(self):
"""
Set the default device for the current process.
"""
if self.device == 'cpu':
if self.device == "cpu":
pass
elif self.device == 'cuda':
elif self.device == "cuda":
# TODO(FrankLeeeee): use global environment to check if it is a dist job
# if is_distributed:
# local_rank = EnvTable().get_local_rank()
# torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
torch.cuda.set_device(torch.device('cuda'))
pass
torch.cuda.set_device(torch.device("cuda"))
else:
raise ValueError(f"Device {self.device} is not supported yet")
......
import warnings
from contextlib import contextmanager
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import torch
import torch.nn as nn
......@@ -8,13 +8,16 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
from .plugin.pp_plugin_base import PipelinePluginBase
__all__ = ['Booster']
__all__ = ["Booster"]
class Booster:
......@@ -22,56 +25,67 @@ class Booster:
Booster is a high-level API for training neural networks. It provides a unified interface for
training with different precision, accelerator, and plugin.
Examples:
>>> colossalai.launch(...)
>>> plugin = GeminiPlugin(stage=3, ...)
>>> booster = Booster(precision='fp16', plugin=plugin)
>>>
>>> model = GPT2()
>>> optimizer = Adam(model.parameters())
>>> dataloader = Dataloader(Dataset)
>>> lr_scheduler = LinearWarmupScheduler()
>>> criterion = GPTLMLoss()
>>>
>>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
>>>
>>> for epoch in range(max_epochs):
>>> for input_ids, attention_mask in dataloader:
>>> outputs = model(input_ids, attention_mask)
>>> loss = criterion(outputs.logits, input_ids)
>>> booster.backward(loss, optimizer)
>>> optimizer.step()
>>> lr_scheduler.step()
>>> optimizer.zero_grad()
```python
# Following is pseudocode
colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(precision='fp16', plugin=plugin)
model = GPT2()
optimizer = HybridAdam(model.parameters())
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)
for epoch in range(max_epochs):
for input_ids, attention_mask in dataloader:
outputs = model(input_ids.cuda(), attention_mask.cuda())
loss = criterion(outputs.logits, input_ids)
booster.backward(loss, optimizer)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
```
Args:
device (str or torch.device): The device to run the training. Default: 'cuda'.
device (str or torch.device): The device to run the training. Default: None.
If plugin is not used or plugin doesn't control the device,
this argument will be set as training device ('cuda' will be used if argument is None).
mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
plugin (Plugin): The plugin to run the training. Default: None.
"""
def __init__(self,
device: str = 'cuda',
mixed_precision: Union[MixedPrecision, str] = None,
plugin: Optional[Plugin] = None) -> None:
def __init__(
self,
device: Optional[str] = None,
mixed_precision: Optional[Union[MixedPrecision, str]] = None,
plugin: Optional[Plugin] = None,
) -> None:
if plugin is not None:
assert isinstance(
plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
plugin, Plugin
), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
self.plugin = plugin
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
if device is not None:
warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
else:
device = device or "cuda"
self.accelerator = Accelerator(device)
# set precision
if self.plugin and self.plugin.control_precision():
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
if mixed_precision is not None:
warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
......@@ -85,7 +99,7 @@ class Booster:
self.mixed_precision = mixed_precision
else:
raise ValueError(
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}."
)
if self.plugin is not None and self.plugin.control_checkpoint_io():
......@@ -96,79 +110,216 @@ class Booster:
def boost(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
Wrap and inject features to the passed in model, optimizer, criterion, lr_scheduler, and dataloader.
Args:
model (nn.Module): The model to be boosted.
optimizer (Optimizer): The optimizer to be boosted.
criterion (Callable): The criterion to be boosted.
dataloader (DataLoader): The dataloader to be boosted.
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
model (nn.Module): Convert model into a wrapped model for distributive training.
The model might be decorated or partitioned by plugin's strategy after execution of this method.
optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributive training.
The optimizer's param groups or states might be decorated or partitioned by plugin's strategy after execution of this method. Defaults to None.
criterion (Callable, optional): The function that calculates loss. Defaults to None.
dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None.
Returns:
List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
"""
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(FrankLeeeee): consider multi-dataloader case
pretrained_path = pretrained_utils.get_pretrained_path(model)
# transform model for mixed precision
if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, criterion, dataloader, lr_scheduler
)
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
model = self.accelerator.configure(model)
model = self.accelerator.configure_model(model)
if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
# transform model for mixed precision
# when mixed_precision is specified and the plugin is not given or does not control the precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
if pretrained_path:
self.load_model(model, pretrained_path)
# clear pretrained path attr
orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
pretrained_utils.set_pretrained_path(orig_model, None)
return model, optimizer, criterion, dataloader, lr_scheduler
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
# TODO: implement this method with plugin
"""Execution of backward during training step.
Args:
loss (torch.Tensor): The loss for backpropagation.
optimizer (Optimizer): The optimizer to be updated.
"""
# TODO(frank lee): implement this method with plugin
optimizer.backward(loss)
def execute_pipeline(self,
def execute_pipeline(
self,
data_iter: Iterator,
model: nn.Module,
criterion: Callable[[torch.Tensor], torch.Tensor],
optimizer: Optimizer,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[Optimizer] = None,
return_loss: bool = True,
return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
# TODO: implement this method
# run pipeline forward backward pass
# return loss or outputs if needed
pass
def no_sync(self, model: nn.Module) -> contextmanager:
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
return_outputs: bool = False,
) -> Dict[str, Any]:
"""
Execute forward & backward when utilizing pipeline parallel.
Return loss or Huggingface style model outputs if needed.
Warning: This function is tailored for the scenario of pipeline parallel.
As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward())
when doing pipeline parallel training with booster, which will cause unexpected errors.
Args:
data_iter(Iterator): The iterator for getting the next batch of data. Usually there are two ways to obtain this argument:
1. wrap the dataloader to iterator through: iter(dataloader)
2. get the next batch from dataloader, and wrap this batch to iterator: iter([batch])
model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that supports pipeline.
criterion: (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
return_output (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.
Returns:
Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None.
"""
assert isinstance(
self.plugin, PipelinePluginBase
), f"The plugin {self.plugin.__class__.__name__} does not support pipeline."
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
"""Context manager to disable gradient synchronization across DP process groups.
Support torch DDP and Low Level ZeRO-1 for now.
Args:
model (nn.Module): The model to be disabled gradient synchronization, for DDP
optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1
Returns:
contextmanager: Context to disable gradient synchronization.
"""
assert (
self.plugin is not None
), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync."
assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
return self.plugin.no_sync(model, optimizer)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.
Args:
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
"""
self.checkpoint_io.load_model(model, checkpoint, strict)
def save_model(self,
model: nn.Module,
def save_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
prefix: str = None,
shard: bool = False,
size_per_shard: int = 1024):
self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
"""Save model to checkpoint.
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
Args:
model (nn.Module or ModelWrapper): A model boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save checkpoint a sharded way.
If true, the checkpoint will be a folder with the same format as Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
"""
self.checkpoint_io.save_model(
model,
checkpoint=checkpoint,
shard=shard,
gather_dtensor=gather_dtensor,
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors,
)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""Load optimizer from checkpoint.
Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
def save_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
) -> None:
"""
Save optimizer to checkpoint.
Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save checkpoint a sharded way.
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.
Args:
lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local file path.
"""
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Load lr scheduler from checkpoint.
Args:
lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local file path.
"""
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
from .bf16 import BF16MixedPrecision
from .fp8 import FP8MixedPrecision
from .fp16_apex import FP16ApexMixedPrecision
from .fp16_naive import FP16NaiveMixedPrecision
from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision
__all__ = [
'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision',
'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision'
"MixedPrecision",
"mixed_precision_factory",
"FP16_Apex_MixedPrecision",
"FP16_Torch_MixedPrecision",
"FP32_MixedPrecision",
"BF16_MixedPrecision",
"FP8_MixedPrecision",
"FP16NaiveMixedPrecision",
]
_mixed_precision_mapping = {
'fp16': FP16TorchMixedPrecision,
'fp16_apex': FP16ApexMixedPrecision,
'bf16': BF16MixedPrecision,
'fp8': FP8MixedPrecision
"fp16": FP16TorchMixedPrecision,
"fp16_apex": FP16ApexMixedPrecision,
"fp16_naive": FP16NaiveMixedPrecision,
"bf16": BF16MixedPrecision,
"fp8": FP8MixedPrecision,
}
......@@ -29,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
return _mixed_precision_mapping[mixed_precision_type]()
else:
raise ValueError(
f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}"
)
from typing import Any, Optional, Union
import torch
from .mixed_precision_base import MixedPrecision
class FP16ApexMixedPrecision(MixedPrecision):
"""
Precision for mixed precision training in FP16 using apex AMP.
Args:
opt_level(str, optional, default="O1" ): Pure or mixed precision optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail above Apex AMP Documentation.
cast_model_type (torch.dtype, optional, default=None): Casts your model’s parameters and buffers to the desired type.
patch_torch_functions (bool, optional, default=None): Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32.
keep_batchnorm_fp32 (bool or str, optional, default=None): To enhance precision and enable cudnn batchnorm (which improves performance), it’s often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16.
master_weights (bool, optional, default=None): Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients.
loss_scale (float or str, optional, default=None): If loss_scale is a float value, use this value as the static (fixed) loss scale. If loss_scale is the string "dynamic", adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically.
cast_model_outputs (torch.dpython:type, optional, default=None): Option to ensure that the outputs of your model(s) are always cast to a particular type regardless of opt_level.
num_losses(int, optional, default=1): Option to tell AMP in advance how many losses/backward passes you plan to use. When used in conjunction with the loss_id argument to `amp.scale_loss`, enables Amp to use a different loss scale per loss/backward pass, which can improve stability. If num_losses is left to 1, Amp will still support multiple losses/backward passes, but use a single global loss scale for all of them.
verbosity(int, default=1): Set to 0 to suppress Amp-related output.
min_loss_scale(float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored.
max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
"""
def __init__(
self,
opt_level: Optional[str] = "O1",
cast_model_type: torch.dtype = None,
patch_torch_functions: bool = None,
keep_batchnorm_fp32: Union[bool, str] = None,
master_weights: bool = None,
loss_scale: Union[float, str] = None,
cast_model_outputs: Any = None,
num_losses: Optional[int] = 1,
verbosity: int = 1,
min_loss_scale: float = None,
max_loss_scale: float = 2.0**24,
) -> None:
pass
from .mixed_precision_base import MixedPrecision
class FP16NaiveMixedPrecision(MixedPrecision):
"""
Precision for mixed precision training in FP16 using naive AMP.
Args:
log_num_zeros_in_grad(bool): return number of zeros in the gradients.
initial_scale(int): initial scale of gradient scaler.
growth_factor(int): the growth rate of loss scale.
backoff_factor(float): the decrease rate of loss scale.
hysteresis(int): delay shift in dynamic loss scaling.
max_scale(int): maximum loss scale allowed.
verbose(bool): if set to `True`, will print debug info.
"""
def __init__(
self,
log_num_zeros_in_grad: bool,
initial_scale: int,
growth_factor: int,
backoff_factor: float,
hysteresis: int,
max_scale: int,
verbose: bool = None,
) -> None:
pass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
......@@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .mixed_precision_base import MixedPrecision
__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
class TorchAMPOptimizer(OptimizerWrapper):
......@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper):
calls that may cause the scale to increase. Default: 2000.
"""
def __init__(self,
def __init__(
self,
optim: Optimizer,
init_scale: float = 2.**16,
init_scale: float = 2.0**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000) -> None:
growth_interval: int = 2000,
) -> None:
super().__init__(optim)
self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
self.scaler = torch.cuda.amp.GradScaler(
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval)
growth_interval=growth_interval,
)
def backward(self, loss: Tensor, *args, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
......@@ -60,12 +64,14 @@ class TorchAMPOptimizer(OptimizerWrapper):
self.unscale_grad()
super().clip_grad_by_value(clip_value, *args, **kwargs)
def clip_grad_by_norm(self,
def clip_grad_by_norm(
self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> None:
**kwargs,
) -> None:
self.unscale_grad()
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
......@@ -102,22 +108,29 @@ class FP16TorchMixedPrecision(MixedPrecision):
calls that may cause the scale to increase. Default: 2000.
"""
def __init__(self,
init_scale: float = 2.**16,
def __init__(
self,
init_scale: float = 2.0**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000) -> None:
growth_interval: int = 2000,
) -> None:
super().__init__()
self.torch_amp_kwargs = dict(init_scale=init_scale,
self.torch_amp_kwargs = dict(
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval)
growth_interval=growth_interval,
)
def configure(self,
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model)
if optimizer is not None:
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
if criterion is not None:
criterion = TorchAMPModule(criterion)
......
from abc import ABC, abstractmethod
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
......@@ -13,9 +13,11 @@ class MixedPrecision(ABC):
"""
@abstractmethod
def configure(self,
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
# TODO: implement this method
pass
from .gemini_plugin import GeminiPlugin
from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
import torch
from packaging import version
if version.parse(torch.__version__) >= version.parse("1.12.0"):
from .torch_fsdp_plugin import TorchFSDPPlugin
__all__.append("TorchFSDPPlugin")
import random
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from .plugin_base import Plugin
class DPPluginBase(Plugin):
"""This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation."""
def __init__(self) -> None:
super().__init__()
assert (
dist.is_initialized()
), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment"
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
import random
import warnings
from typing import Callable, List, Optional, Tuple, Union
import gc
import logging
import os
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import save_state_dict
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
save_config_file,
save_state_dict,
save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
from .plugin_base import Plugin
from .dp_plugin_base import DPPluginBase
__all__ = ['GeminiPlugin']
__all__ = ["GeminiPlugin"]
SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
class GeminiCheckpointIO(GeneralCheckpointIO):
class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
"""
Save model to checkpoint but only on master process.
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
The saving process will only be executed by master rank.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
# as there is communication when get state dict, this must be called on all processes
state_dict = model.state_dict(only_rank_0=True)
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
"""
Save optimizer to checkpoint but only on master process.
Loading unsharded optimizer from checkpoint file.
For each process, only loading optimizer states of parameters it controls.
"""
# TODO(ver217): optimizer state dict is sharded
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
def save_sharded_model(
self,
model: GeminiDDP,
checkpoint_path: str,
gather_dtensor: bool = False,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
"""
Save model to checkpoint but only on master process.
Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)
# only save the index file on the master rank
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
):
"""
Load shard model, load model from multiple files.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
def save_sharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
param_groups = optimizer.get_param_groups_for_saving()
torch.save(param_groups, group_file_path)
# States are broken into shards within max_shard_size.
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=is_master,
use_safetensors=False,
)
class GeminiModel(ModelWrapper):
# Wrap up index file. Only save it on master rank.
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
super().__init__(module)
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
def unwrap(self):
# as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
return self.module
assert isinstance(optimizer, GeminiOptimizer)
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
class GeminiOptimizer(OptimizerWrapper):
# Load param_groups.
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_param_groups = torch.load(param_group_path)
optimizer.load_param_groups(saved_param_groups)
def __init__(self,
module: GeminiDDP,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)
# Load optimizer states from shard files under checkpoint path.
# For each file, only load the states managed by current process.
for shard_file in checkpoint_files:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
gc.collect()
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
optimizer.optimizer_loading_epilogue()
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
class GeminiPlugin(Plugin):
class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import GeminiPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = GeminiPlugin()
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
model, train_dataset, optimizer, criterion = ...
plugin = GeminiPlugin()
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
chunk_config_dict (dict, optional): chunk configuration dictionary.
chunk_init_device (torch.device, optional): device to initialize the chunk.
placement_policy (str, optional): "static" and "auto". Defaults to "static".
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
Defaults to 0.0.
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
If the aggregate size of parameters is still samller than the minimum chunk size,
min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
......@@ -152,17 +287,24 @@ class GeminiPlugin(Plugin):
def __init__(
self,
device: Optional[torch.device] = None,
placement_policy: str = "cpu",
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
search_range_mb: int = 32,
search_range_m: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
......@@ -173,24 +315,31 @@ class GeminiPlugin(Plugin):
norm_type: float = 2.0,
verbose: bool = False,
) -> None:
assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
self.gemini_config = dict(
device=(device or get_current_device()),
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=strict_ddp_mode,
search_range_mb=search_range_mb,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb,
min_chunk_size_m=min_chunk_size_m,
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
)
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
self.optim_kwargs = dict(initial_scale=initial_scale,
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
)
self.optim_kwargs = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
......@@ -198,7 +347,8 @@ class GeminiPlugin(Plugin):
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
norm_type=norm_type,
)
self.verbose = verbose
def support_no_sync(self) -> bool:
......@@ -208,74 +358,22 @@ class GeminiPlugin(Plugin):
return True
def supported_precisions(self) -> List[str]:
return ['fp16']
return SUPPORTED_PRECISION
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
def prepare_train_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Note:
1. Evaluation datasets should not be passed to this function.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
return ["cuda"]
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
......@@ -287,11 +385,12 @@ class GeminiPlugin(Plugin):
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with Gemini
model = GeminiModel(model, self.gemini_config, self.verbose)
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
)
return model, optimizer, criterion, dataloader, lr_scheduler
......@@ -300,3 +399,6 @@ class GeminiPlugin(Plugin):
def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError
import random
from contextlib import nullcontext
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.to(dtype)
return x
class HybridParallelModule(ModelWrapper):
def __init__(
self,
module: Module,
precision: str,
shard_config: ShardConfig,
dp_group: ProcessGroup,
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
shardformer = ShardFormer(shard_config)
if custom_policy is not None:
assert isinstance(custom_policy, object)
module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
# setting process groups for shared parameters
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
)
# setting mixed_precision
self.mixed_precision = None
if precision == "fp16":
self.mixed_precision = torch.float16
elif precision == "bf16":
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
module = module.cuda()
# setting input type cast when using mixed precision
self.convert_fn = None
if self.mixed_precision is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
# setting ddp configs
if use_ddp:
# convert model to sync bn
module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
# wrap the model with PyTorch DDP
module = DDP(module, process_group=dp_group, **ddp_config)
super().__init__(module)
def sync_shared_params(self):
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
if self.stage_manager.stage in shared_param:
param = shared_param[self.stage_manager.stage]
dist.all_reduce(param.grad, group=group)
dist.barrier()
def no_sync(self) -> Iterator[None]:
# no sync grads across data parallel
return nullcontext()
def sync_grads(self):
# sync grad across data parallel
if self.dp_group.size() == 1:
return
for p in self.module.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, group=self.dp_group)
p.grad.div_(self.dp_group.size())
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
def unwrap(self):
module = super().unwrap()
if isinstance(module, DDP):
module = module.module
return module
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A complete param_group, with params in the form of param_id
# 2. A mapping from param address (obtained using id(param)) to integer param_id
# 3. A mapping from integer param_id to param address.
# 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding.
# When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
if optim is None:
return {}
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
packed_group["params"] = []
for param_id, param in enumerate(group["params"], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
packed_group["params"].append(param_id)
param_info["param2id"][id(param)] = param_id
param_info["id2param"][param_id] = id(param)
param_info["param2shape"][id(param)] = original_shape
param_info["param_groups"].append(packed_group)
start_index += len(group["params"])
return param_info
def init_pipeline_optimizer(optim: Optimizer, model: Module):
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
params = [p for p in group["params"] if p in model_params]
new_param_groups.append({**group, "params": params})
optim.__setstate__({"param_groups": new_param_groups})
class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim)
def update_master_params(self, model: Module):
pass
def get_working_to_master_map(self):
return None
def get_master_to_working_map(self):
return None
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
self,
optim: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
precision: str = "fp16",
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(
optim,
precision,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
max_norm,
)
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None,
):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
optimizer,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
clip_grad_norm,
verbose,
reduce_bucket_size,
communication_dtype,
overlap_communication,
partition_grad,
cpu_offload,
dp_process_group,
tp_process_group,
forced_dtype,
)
class HybridParallelPlugin(PipelinePluginBase):
"""
Plugin for Hybrid Parallel Training.
Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
model, train_dataset, optimizer, criterion = ...
plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
```
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
"""
def __init__(
self,
tp_size: int,
pp_size: int,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
custom_policy: Policy = None,
) -> None:
super().__init__()
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
)
self.ddp_config = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.zero_config = dict(
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
)
self.max_norm = max_norm
@property
def enable_pipeline_parallelism(self) -> bool:
return self.pp_size > 1
def supported_devices(self) -> List[str]:
return ["cuda"]
def supported_precisions(self) -> List[str]:
return ["fp16", "bf16", "fp32"]
def control_device(self) -> bool:
return True
def control_precision(self) -> bool:
return True
def support_no_sync(self) -> bool:
return False
def control_checkpoint_io(self) -> bool:
return True
def configure(
self,
model: Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config,
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**self.amp_config,
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
def execute_pipeline(
self,
data_iter: Iterator,
model: HybridParallelModule,
criterion: Callable[[Any, Any], torch.Tensor],
optimizer: Optional[
Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer]
] = None,
return_loss: bool = True,
return_outputs: bool = False,
) -> dict:
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
# return loss or outputs if needed
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
model.sync_shared_params()
if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad()
else:
model.sync_grads()
return outputs
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
import random
import warnings
from typing import Callable, List, Optional, Tuple, Union
import logging
import os
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
get_shard_filename,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import LowLevelZeroOptimizer
from .plugin_base import Plugin
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
__all__ = ['LowLevelZeroPlugin']
__all__ = ["LowLevelZeroPlugin"]
def _convert_to_fp16(x):
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.half()
return x.to(dtype)
return x
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class LowLevelZeroModel(ModelWrapper):
def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
self.convert_inputs = (precision == 'fp16')
module = zero_model_wrapper(module, zero_stage=stage)
if precision == 'fp16':
module = module.half()
self.dtype = None
if precision == "fp16":
self.dtype = torch.float16
elif precision == "bf16":
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
def forward(self, *args, **kwargs):
if self.convert_inputs:
args = tree_map(_convert_to_fp16, args)
kwargs = tree_map(_convert_to_fp16, kwargs)
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(self,
module: nn.Module,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)
def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
"""Save optimizer to checkpoint but only on master process.
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
Args:
optimizer (OptimizerWrapper): Optimizer to save state_dict
checkpoint (str): Path to save checkpoint
gather_dtensor (bool): Whether to gather_dtensor, not used
"""
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024,
):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Perfix of file to save
size_per_shard (int): Max file size of each file that store state tensors
"""
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# state_dict only provide only 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
total_size = 0
for idx, shard_pair in enumerate(sharded_state):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
"""Load sharded optimizer with the given path to index file.
class LowLevelZeroPlugin(Plugin):
Args:
optimizer (OptimizerWrapper): Optimizer to load state_dict
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!"
optimizer = optimizer.unwrap()
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory."
)
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
padding_size = (
self.coordinator.world_size - v.numel() % self.coordinator.world_size
) % self.coordinator.world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict)
model.update_master_params()
def load_sharded_model(
self,
model: ModelWrapper,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
class LowLevelZeroPlugin(DPPluginBase):
"""
Plugin for low level zero.
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import LowLevelZeroPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = LowLevelZeroPlugin()
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
model, train_dataset, optimizer, criterion = ...
plugin = LowLevelZeroPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
strage (int, optional): ZeRO stage. Defaults to 1.
precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
stage (int, optional): ZeRO stage. Defaults to 1.
precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
......@@ -126,7 +248,7 @@ class LowLevelZeroPlugin(Plugin):
def __init__(
self,
stage: int = 1,
precision: str = 'fp16',
precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
......@@ -142,113 +264,64 @@ class LowLevelZeroPlugin(Plugin):
cpu_offload: bool = False,
verbose: bool = False,
) -> None:
assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
self.stage = stage
self.precision = precision
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload)
self.optim_kwargs = dict(initial_scale=initial_scale,
self.zero_optim_kwargs = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
clip_grad_norm=max_norm,
reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(stage == 2),
)
self.verbose = verbose
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
def support_no_sync(self) -> bool:
return False
return self.stage == 1
def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
return ['fp16', 'fp32']
return SUPPORTED_PRECISION
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
def prepare_train_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Note:
1. Evaluation datasets should not be passed to this function.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
return ["cuda"]
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.stage, self.precision)
model = LowLevelZeroModel(model, self.precision)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer, **self.zero_optim_kwargs, verbose=self.verbose
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
......@@ -257,3 +330,7 @@ class LowLevelZeroPlugin(Plugin):
def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.optim.no_sync()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment