Unverified Commit df635641, authored by Hongxin Liu, committed by GitHub

[gemini] support amp o3 for gemini (#4872)

* [gemini] support no reuse fp16 chunk

* [gemini] support no master weight for optim

* [gemini] support no master weight for gemini ddp

* [test] update gemini tests

* [test] update gemini tests

* [plugin] update gemini plugin

* [test] fix gemini checkpointio test

* [test] fix gemini checkpoint io
parent c1fab951
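
Usage sketch (not part of this commit's diff): with the new `master_weights` flag, the plugin can run an apex-O3-style setup in which no fp32 master copy is kept and the optimizer updates the half-precision weights in place. The model below is a placeholder, and the snippet assumes the distributed environment has already been initialized via `colossalai.launch`.

```python
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# master_weights=False: no fp32 master weights; grads go to dedicated grad chunks
plugin = GeminiPlugin(precision="fp16", master_weights=False)
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)  # placeholder module
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
```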
@@ -97,7 +97,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
-        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
+        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)
@@ -257,6 +257,7 @@ class GeminiPlugin(DPPluginBase):
        warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
        steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
        precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
+        master_weights (bool, optional): whether to keep fp32 master weights for optimizer updates. Defaults to True.
        pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
        force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
        strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -296,6 +297,7 @@ class GeminiPlugin(DPPluginBase):
         warmup_non_model_data_ratio: float = 0.8,  # only for auto placement
         steady_cuda_cap_ratio: float = 0.9,  # only for auto placement
         precision: str = "fp16",
+        master_weights: bool = True,
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
         strict_ddp_mode: bool = False,
@@ -334,6 +336,7 @@ class GeminiPlugin(DPPluginBase):
             min_chunk_size_m=min_chunk_size_m,
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
+            master_weights=master_weights,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
......
@@ -132,9 +132,6 @@ class CPUAdam(NVMeOptimizer):
                 target_device = p.device
                 if len(state) == 0:
                     state["step"] = 0
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
                     # gradient momentums
                     state["exp_avg"] = torch.zeros_like(p, device=target_device)
                     # gradient variances
@@ -149,7 +146,8 @@ class CPUAdam(NVMeOptimizer):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    if p.grad.dtype is torch.bfloat16:
+                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
+                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]
......
@@ -108,9 +108,6 @@ class HybridAdam(CPUAdam):
                 target_device = p.device
                 if len(state) == 0:
                     state["step"] = 0
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
                     # gradient momentums
                     state["exp_avg"] = torch.zeros_like(p, device=target_device)
                     # gradient variances
@@ -125,7 +122,8 @@ class HybridAdam(CPUAdam):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    if p.grad.dtype is torch.bfloat16:
+                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
+                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]
......
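
In both optimizers, the widened condition routes any parameter the fused CPU Adam kernel cannot handle (bf16 grads, or non-fp32 params now that master weights are optional) to the plain PyTorch update path. For reference, a self-contained sketch of one bias-corrected Adam step of that shape; names and hyperparameters are illustrative, not the kernel's interface:

```python
import torch

p = torch.randn(4, dtype=torch.bfloat16)      # bf16 param, as in the fallback branch
g = torch.randn_like(p)                        # gradient
m = torch.zeros_like(p, dtype=torch.float)     # exp_avg kept in fp32
v = torch.zeros_like(p, dtype=torch.float)     # exp_avg_sq kept in fp32
step, lr, beta1, beta2, eps = 1, 1e-3, 0.9, 0.999, 1e-8

bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
m.mul_(beta1).add_(g.float(), alpha=1 - beta1)
v.mul_(beta2).addcmul_(g.float(), g.float(), value=1 - beta2)
denom = (v / bias_correction2).sqrt_().add_(eps)
# update = -lr * (m / bc1) / (sqrt(v / bc2) + eps), cast back to the param dtype
p.data.add_((-lr / bias_correction1) * (m / denom).to(p.dtype))
```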
@@ -40,7 +40,7 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
         assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"

-def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
+def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False):
     assert len(list(d1.keys())) == len(
         list(d2.keys())
     ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
@@ -58,6 +58,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
                 if not ignore_device:
                     v1_i = v1_i.to("cpu")
                     v2_i = v2_i.to("cpu")
+                if ignore_dtype:
+                    v1_i = v1_i.to(v2_i.dtype)
                 assert_close_loose(v1_i, v2_i)
             elif isinstance(v1_i, dict):
                 assert isinstance(v2_i, dict)
@@ -69,6 +71,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
             if not ignore_device:
                 v1 = v1.to("cpu")
                 v2 = v2.to("cpu")
+            if ignore_dtype:
+                v1 = v1.to(v2.dtype)
             assert_close_loose(v1, v2)
         else:
             assert v1 == v2, f"{v1} not equals to {v2}"
......
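
With this change, `GeminiDDP.state_dict` comes back in the module's mixed precision rather than fp32, so tests compare against fp32 references with `ignore_dtype=True`, which casts the left-hand tensor to the right-hand dtype before the loose comparison. A minimal usage sketch, assuming `check_state_dict_equal` is exported from `colossalai.testing` as in the test suite; values are illustrative:

```python
from collections import OrderedDict

import torch
from colossalai.testing import check_state_dict_equal

# fp16 checkpoint vs. fp32 reference: equal values, different dtypes
d1 = OrderedDict(weight=torch.ones(4, dtype=torch.float16))
d2 = OrderedDict(weight=torch.ones(4, dtype=torch.float32))

check_state_dict_equal(d1, d2, ignore_device=True, ignore_dtype=True)  # passes
```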
@@ -160,6 +160,8 @@ class Chunk:
         self.l2_norm_flag = False
         self.l2_norm = None

+        self.grad_chunk = None
+
     @property
     def memory_usage(self) -> Dict[str, int]:
         cuda_memory = 0
@@ -414,7 +416,9 @@ class Chunk:
             return
         self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

-    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
+    def copy_tensor_to_chunk_slice(
+        self, tensor: torch.Tensor, data_slice: torch.Tensor, update_ptr: bool = True
+    ) -> None:
         """
         Copy data slice to the memory space indexed by the input tensor in the chunk.
@@ -427,7 +431,8 @@ class Chunk:
         tensor_info = self.tensors_info[tensor]
         self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten())
-        tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
+        if update_ptr:
+            tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)

     def get_valid_length(self) -> int:
         """Get the valid length of the chunk's payload."""
@@ -577,3 +582,46 @@ class Chunk:
             output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))
         return "".join(output)
+
+    def init_grad_chunk(self) -> "Chunk":
+        """Init grad chunk. This should be called in grad handler.
+
+        Returns:
+            Chunk: Grad chunk
+        """
+        if self.grad_chunk is None:
+            # grad chunk is not initialized
+            grad_chunk = Chunk(
+                chunk_size=self.chunk_size,
+                process_group=self.torch_pg,
+                dtype=self.dtype,
+                keep_gathered=self.keep_gathered,
+                pin_memory=self.pin_memory,
+            )
+            grad_chunk.num_tensors = self.num_tensors
+            grad_chunk.utilized_size = self.utilized_size
+            grad_chunk.tensor_state_cnter[TensorState.HOLD] = self.num_tensors
+            for tensor, state in self.tensors_info.items():
+                grad_chunk.tensors_info[tensor] = TensorInfo(TensorState.HOLD, state.offset, state.end)
+            grad_chunk.valid_end = self.valid_end
+
+            if grad_chunk.chunk_temp.device.type == "cpu":
+                grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
+            else:
+                grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
+            grad_chunk.chunk_temp = None
+
+            if grad_chunk.pin_memory:
+                grad_chunk.cpu_shard = torch.empty(
+                    grad_chunk.shard_size, dtype=grad_chunk.dtype, pin_memory=grad_chunk.pin_memory
+                )
+
+            self.grad_chunk = grad_chunk
+        else:
+            # grad chunk is initialized, just reallocate cuda global chunk
+            self.grad_chunk.cuda_shard = None
+            self.grad_chunk.is_gathered = True
+            alloc_storage(self.grad_chunk.cuda_global_chunk)
+
+        return self.grad_chunk
@@ -245,3 +245,13 @@ class ChunkManager:
             chunk.release_chunk()
             self.accessed_chunks.remove(chunk)
             self.accessed_mem -= chunk.chunk_mem
+
+    def init_grad_chunk(self, chunk: Chunk) -> Chunk:
+        if chunk.grad_chunk is not None:
+            self.__sub_memory_usage(chunk.grad_chunk.memory_usage)
+        grad_chunk = chunk.init_grad_chunk()
+        self.__add_memory_usage(grad_chunk.memory_usage)
+        if grad_chunk not in self.accessed_chunks:
+            self.accessed_chunks.add(grad_chunk)
+            self.accessed_mem += grad_chunk.chunk_mem
+        return grad_chunk
@@ -74,6 +74,7 @@ class GeminiDDP(ModelWrapper):
         mixed_precision: torch.dtype = torch.float16,
         process_group: Optional[ProcessGroup] = None,
         memstats: Optional[MemStats] = None,  # gemini memory stats
+        master_weights: bool = True,
         verbose: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
@@ -115,6 +116,9 @@ class GeminiDDP(ModelWrapper):
         self.mixed_precision = mixed_precision
         self.dp_process_group = process_group or _get_default_group()

+        self.reuse_fp16_chunk = master_weights
+        self.master_weights = master_weights
+
         self._logger = get_dist_logger()

         if self.gemini_manager._premade_memstats_:
@@ -321,20 +325,37 @@ class GeminiDDP(ModelWrapper):
                     f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
                     "Some unsupported torch function is operated upon this parameter."
                 )
-            self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
-            chunk.copy_tensor_to_chunk_slice(p, grad)
-            reduced = self.chunk_manager.reduce_chunk(chunk)
+            grad_chunk = chunk
+            if not self.reuse_fp16_chunk:
+                grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
+                # hold -> compute -> hold after bwd
+                grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
+                grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
+                # fp16 param chunk: hold after bwd -> ready for reduce -> hold
+                chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
+                chunk.tensor_trans_state(p, TensorState.HOLD)
+            grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
+            grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
+            reduced = self.chunk_manager.reduce_chunk(grad_chunk)
             if reduced:
-                if chunk.is_gathered:
-                    chunk.cuda_global_chunk.div_(chunk.pg_size)
+                if not self.reuse_fp16_chunk:
+                    if chunk.keep_gathered:
+                        self.chunk_manager.fake_release_chunk(chunk)
+                    else:
+                        self.chunk_manager.release_chunk(chunk)
+                if grad_chunk.is_gathered:
+                    grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
                 else:
-                    chunk.cuda_shard.div_(chunk.pg_size)
+                    grad_chunk.cuda_shard.div_(chunk.pg_size)
                 # check overflow elements
-                self.overflow_counter += chunk.has_inf_or_nan
-                # record l2 norm for gradient clipping
+                self.overflow_counter += grad_chunk.has_inf_or_nan
+                # record l2 norm for gradient clipping. flag is bound to fp16 chunk
                 if chunk.l2_norm_flag:
-                    chunk.set_l2_norm()
-                self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
+                    grad_chunk.set_l2_norm()
+                self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
+                if not self.master_weights:
+                    self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
         return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
@@ -344,9 +365,7 @@ class GeminiDDP(ModelWrapper):
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device

-    def state_dict(
-        self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16
-    ):
+    def state_dict(self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True):
         """Returns a dictionary containing a whole state of the module.

         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -365,7 +384,7 @@ class GeminiDDP(ModelWrapper):
             destination = OrderedDict()
             destination._metadata = OrderedDict()
         destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
-        self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype)
+        self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
         for hook in self._state_dict_hooks.values():
             hook_result = hook(self, destination, prefix, local_metadata)
@@ -373,7 +392,7 @@ class GeminiDDP(ModelWrapper):
                 destination = hook_result
         return destination

-    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict:
+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
         """
         get gathered chunk content.
@@ -386,9 +405,8 @@ class GeminiDDP(ModelWrapper):
         """
         # save parameters
         chunk_to_save_data = dict()
-        temp_chunk = get_temp_total_chunk_on_cuda(chunk)
-        if torch.is_floating_point(temp_chunk):
-            temp_chunk = temp_chunk.to(dtype)
+        temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
         for tensor, tensor_info in chunk.tensors_info.items():
             record_tensor = torch.empty([0])
             record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
@@ -401,9 +419,7 @@ class GeminiDDP(ModelWrapper):
         del temp_chunk
         return chunk_to_save_data

-    def _get_param_to_save_data(
-        self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype
-    ) -> Dict:
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
         """
         get param content from chunks.
@@ -418,10 +434,10 @@ class GeminiDDP(ModelWrapper):
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
         return param_to_save_data

-    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16):
+    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every
         submodule in :meth:`~torch.nn.Module.state_dict`.
@@ -438,14 +454,18 @@ class GeminiDDP(ModelWrapper):
         # get copies of fp32 parameters in CPU
         # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
-        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype)
+        params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
+        param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
         # get the mapping between copies and fp16 parameters
         p_mapping = dict()
-        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
-            name = self.param2name[p]
-            assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
-            record_parameter = param_to_save_data[fp32_p]
-            p_mapping[p] = record_parameter
+        if self.reuse_fp16_chunk:
+            for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+                name = self.param2name[p]
+                assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+                record_parameter = param_to_save_data[fp32_p]
+                p_mapping[p] = record_parameter
+        else:
+            p_mapping = param_to_save_data
         for name, param in self.name2param.items():
             if param is not None:
                 if is_ddp_ignored(param):
@@ -593,7 +613,7 @@ class GeminiDDP(ModelWrapper):
             elif strict:
                 missing_keys.append(state_key)

-        def load_fp32_parameter(chunk_slice, data):
+        def load_parameter(chunk_slice, data):
             chunk_slice.copy_(data.flatten())

         for name, param in self.named_parameters():
@@ -607,14 +627,15 @@ class GeminiDDP(ModelWrapper):
             name = self.param2name[p]
             fp32_to_name[fp32_p] = name

-        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
+        params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
+        chunk_list = self.chunk_manager.get_chunks(params_to_load)
         for chunk in chunk_list:
-            temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+            temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
             for tensor, tensor_info in chunk.tensors_info.items():
-                parameter_name = fp32_to_name[tensor]
+                parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
                 parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
-                load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
+                load(parameter_name, tensor, partial(load_parameter, parameter_slice))

             if chunk.is_gathered:
                 chunk.cuda_global_chunk.copy_(temp_chunk)
@@ -624,11 +645,11 @@ class GeminiDDP(ModelWrapper):
                 chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end])

             del temp_chunk

-        for chunk_32 in chunk_list:
-            chunk_16 = chunk_32.paired_chunk
-            assert chunk_16 is not None
-            chunk_16.payload.copy_(chunk_32.payload)
+        if self.reuse_fp16_chunk:
+            for chunk_32 in chunk_list:
+                chunk_16 = chunk_32.paired_chunk
+                assert chunk_16 is not None
+                chunk_16.payload.copy_(chunk_32.payload)

         for name, buf in persistent_buffers.items():
             if buf is not None:
@@ -668,12 +689,9 @@ class GeminiDDP(ModelWrapper):
                 p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
                 continue

-            # create a fp32 parameter
-            fp32_p = p.data.float()
             # create a fp16 parameter
             p.data = p.data.to(self.mixed_precision)
-            # register the fp16 parameter and fp32 parameter in the chunk manager
+            # register the fp16 parameter
             self.chunk_manager.register_tensor(
                 tensor=p,
                 group_type="fp16_param",
@@ -682,22 +700,27 @@ class GeminiDDP(ModelWrapper):
                 cpu_offload=cpu_offload,
                 pin_memory=pin_memory,
             )
-            self.chunk_manager.register_tensor(
-                tensor=fp32_p,
-                group_type="fp32_param",
-                config_key=dp_world_size,
-                process_group=self.dp_process_group,
-                cpu_offload=cpu_offload,
-                pin_memory=pin_memory,
-            )
             self.fp16_params.append(p)
-            self.fp32_params.append(fp32_p)
+
+            if self.master_weights:
+                # create a fp32 parameter
+                fp32_p = p.data.float()
+                self.chunk_manager.register_tensor(
+                    tensor=fp32_p,
+                    group_type="fp32_param",
+                    config_key=dp_world_size,
+                    process_group=self.dp_process_group,
+                    cpu_offload=cpu_offload,
+                    pin_memory=pin_memory,
+                )
+                self.fp32_params.append(fp32_p)

         self.chunk_manager.close_all_groups()

         self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)

         # move master weights to corresponding device and setup paired chunks
+        # if no master weights, fp32_params should be empty and this loop will be skipped
         for p, fp32_p in zip(self.fp16_params, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
@@ -734,7 +757,6 @@ class GeminiDDP(ModelWrapper):
         keep_vars: bool = False,
         max_shard_size: int = 1024,
         only_rank_0: bool = True,
-        dtype: torch.dtype = torch.float16,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
@@ -769,11 +791,11 @@ class GeminiDDP(ModelWrapper):
                     gathered_param = param if keep_vars else param.detach()
                 else:
                     # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
-                    fp32_param = fp16_to_fp32[param]
-                    if fp32_param not in gathered_param_buffer:
-                        chunk = self.chunk_manager.get_chunk(fp32_param)
-                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
-                    gathered_param = gathered_param_buffer.pop(fp32_param)
+                    param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
+                    if param_to_save not in gathered_param_buffer:
+                        chunk = self.chunk_manager.get_chunk(param_to_save)
+                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                    gathered_param = gathered_param_buffer.pop(param_to_save)

                 block, block_size = sharder.append_param(prefix + name, gathered_param)
                 if block is not None:
......
@@ -105,7 +105,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
         self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
-        self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
+        self.param_to_chunk16: Dict[Parameter, Chunk] = dict()
         self.chunk16_set: Set[Chunk] = set()
         self.clipping_flag = max_norm > 0.0
         self.max_norm = max_norm
@@ -130,7 +130,7 @@ class GeminiOptimizer(OptimizerWrapper):
             else:
                 ddp_param_list.append(param)

-        for p, fp32_p in zip(ddp_param_list, module.fp32_params):
+        for p in ddp_param_list:
             chunk_16 = self.chunk_manager.get_chunk(p)
             if chunk_16 not in self.chunk16_set:
                 chunk_16.l2_norm_flag = self.clipping_flag
@@ -174,13 +174,15 @@ class GeminiOptimizer(OptimizerWrapper):
     def _set_grad_ptr(self):
         for group in self.param_groups:
             for fake_param in group["params"]:
-                chunk32 = self.param_to_chunk32[fake_param]
+                chunk16 = self.param_to_chunk16[fake_param]
                 begin, end = self.param_to_range[fake_param]
-                chunk16 = chunk32.paired_chunk

-                fake_param.data = chunk16.payload[begin:end]
+                grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
+                fake_param.data = grad_chunk16.payload[begin:end]
                 fake_param.grad = fake_param.data
-                fake_param.data = chunk32.payload[begin:end]
+
+                to_update_chunk = chunk16.paired_chunk if self.module.master_weights else chunk16
+                fake_param.data = to_update_chunk.payload[begin:end]

     def _update_fp16_params(self):
         none_tensor = torch.empty([0])
@@ -194,23 +196,25 @@ class GeminiOptimizer(OptimizerWrapper):
     def _clear_global_norm(self) -> None:
         for c16 in self.chunk16_set:
-            c16.l2_norm = None
+            grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
+            grad_chunk.l2_norm = None

     def _calc_global_norm(self) -> float:
         norm_sqr: float = 0.0
         group_to_norm = dict()
         for c16 in self.chunk16_set:
-            assert c16.l2_norm is not None
+            grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
+            assert grad_chunk.l2_norm is not None

-            if c16.is_gathered:
-                norm_sqr += c16.l2_norm
+            if grad_chunk.is_gathered:
+                norm_sqr += grad_chunk.l2_norm
             else:
                 # this chunk is sharded, use communication to collect total norm
-                if c16.torch_pg not in group_to_norm:
-                    group_to_norm[c16.torch_pg] = 0.0
-                group_to_norm[c16.torch_pg] += c16.l2_norm
+                if grad_chunk.torch_pg not in group_to_norm:
+                    group_to_norm[grad_chunk.torch_pg] = 0.0
+                group_to_norm[grad_chunk.torch_pg] += grad_chunk.l2_norm

-            c16.l2_norm = None  # clear l2 norm
+            grad_chunk.l2_norm = None  # clear l2 norm

         comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
         for group, part_norm in group_to_norm.items():
@@ -237,7 +241,8 @@ class GeminiOptimizer(OptimizerWrapper):
         return self.optim.zero_grad(set_to_none=True)

     def step(self, *args, **kwargs):
-        self._maybe_move_fp32_params()
+        if self.module.master_weights:
+            self._maybe_move_fp32_params()
         self._set_grad_ptr()

         if self.mix_precision_mixin.should_skip_step():
@@ -245,7 +250,8 @@ class GeminiOptimizer(OptimizerWrapper):
                 self._logger.info(f"Found overflow. Skip step")
             self._clear_global_norm()  # clear recorded norm
             self.zero_grad()  # reset all gradients
-            self._update_fp16_params()
+            if self.module.reuse_fp16_chunk:
+                self._update_fp16_params()
             return

         # get combined scale. combined scale = loss scale * clipping norm
@@ -255,7 +261,8 @@ class GeminiOptimizer(OptimizerWrapper):
         ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
         self._register_states()
         self.zero_grad()
-        self._update_fp16_params()
+        if self.module.master_weights:
+            self._update_fp16_params()
         return ret

     def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
@@ -282,8 +289,8 @@ class GeminiOptimizer(OptimizerWrapper):
         for group in self.param_groups:
             for fake_param in group["params"]:
-                chunk32 = self.param_to_chunk32[fake_param]
-                chunk16 = chunk32.paired_chunk
+                chunk16 = self.param_to_chunk16[fake_param]
+                chunk32 = chunk16.paired_chunk

                 if chunk32.device_type == "cuda":
                     continue
@@ -297,7 +304,8 @@ class GeminiOptimizer(OptimizerWrapper):
         for group in self.param_groups:
             for fake_param in group["params"]:
-                chunk32 = self.param_to_chunk32[fake_param]
+                chunk16 = self.param_to_chunk16[fake_param]
+                chunk32 = chunk16.paired_chunk
                 if chunk32.device_type == "cuda":
                     state = self.optim.state[fake_param]
                     for k, v in state.items():
@@ -341,7 +349,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 continue
             grad_device = self.module.grads_device[param]
             fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
-            self.param_to_chunk32[fake_param] = chunk16.paired_chunk
+            self.param_to_chunk16[fake_param] = chunk16
             self.param_to_range[fake_param] = range_pair
             self.id_to_fake_params[param_id] = fake_param
             fake_params_list.append(fake_param)
@@ -366,7 +374,7 @@ class GeminiOptimizer(OptimizerWrapper):
         if param_id not in self.id_to_fake_params:
             return -1, -1, -1
         fake_param = self.id_to_fake_params[param_id]
-        chunk = self.param_to_chunk32[fake_param].paired_chunk
+        chunk = self.param_to_chunk16[fake_param]
         param = self.id_to_real_params[param_id]
         param_info = chunk.tensors_info[param]
......
@@ -11,7 +11,7 @@ from colossalai.utils import get_current_device
 from .chunk import Chunk

-def get_temp_total_chunk_on_cuda(chunk: Chunk):
+def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
     if chunk.is_gathered:
         return chunk.cuda_global_chunk
@@ -20,7 +20,9 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     else:
         shard_temp = chunk.cpu_shard.to(get_current_device())

-    total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
+    shard_temp = shard_temp.to(dtype)
+
+    total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device())
     gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
     dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
......
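
The dtype-aware gather above relies on `torch.chunk` returning views of `total_temp`, so the `all_gather` writes land directly in the full buffer. A minimal single-process sketch of that pattern, with the `dist.all_gather` call stood in by a local copy; sizes and dtype are illustrative:

```python
import torch

chunk_size, pg_size = 8, 4
total_temp = torch.zeros(chunk_size, dtype=torch.float16)

# torch.chunk returns views, so filling gather_list[i] fills total_temp in place
gather_list = list(torch.chunk(input=total_temp, chunks=pg_size, dim=0))
for rank, shard in enumerate(gather_list):
    shard.copy_(torch.full_like(shard, float(rank)))  # stand-in for dist.all_gather

print(total_temp)  # tensor([0., 0., 1., 1., 2., 2., 3., 3.], dtype=torch.float16)
```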
@@ -58,9 +58,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
     dist.barrier()
     new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
-    check_state_dict_equal(
-        bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False
-    )
+    check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)

 @clear_cache_before_run()
@@ -100,7 +98,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     dist.barrier()

     booster.load_model(new_model, model_ckpt_path)
-    check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
+    check_state_dict_equal(
+        model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
+    )

     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(
@@ -136,7 +136,7 @@ def exam_lazy_from_pretrained():
         booster.save_model(model, save_path, shard=False)
         dist.barrier()
         state_dict = torch.load(save_path, map_location="cpu")
-        check_state_dict_equal(state_dict, orig_state_dict, False)
+        check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)

 def run_dist(rank, world_size, port):
......
@@ -60,9 +60,10 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
     # Add prefix to get aligned with pytorch parameter names.
     check_state_dict_equal(
-        model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+        model.state_dict(only_rank_0=False, prefix="module.module."),
         new_model.state_dict(),
         False,
+        ignore_dtype=True,
     )

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
@@ -125,9 +126,10 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
     # Add prefix to get aligned with pytorch parameter names.
     check_state_dict_equal(
-        new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32),
+        new_model.state_dict(only_rank_0=False, prefix="module.module."),
         model.state_dict(),
         False,
+        ignore_dtype=True,
     )

     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
......
@@ -27,6 +27,8 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
     chunk_manager = model.chunk_manager
     param_list = [p for p in model.parameters()]
     chunk_list = chunk_manager.get_chunks(param_list)
+    if not model.reuse_fp16_chunk:
+        chunk_list = [chunk.grad_chunk for chunk in chunk_list]
     for chunk in chunk_list:
         chunk_manager.access_chunk(chunk)
@@ -36,13 +38,15 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gather", [False, True])
-@parameterize("model_name", ["gpt2", "bert", "albert"])
+@parameterize("model_name", ["gpt2", "bert"])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("master_weights", [False, True])
 def exam_gpt_fwd_bwd(
     placement_config,
     keep_gather,
     model_name: str,
     use_grad_checkpoint: bool = False,
+    master_weights: bool = True,
 ):
     init_device = get_current_device()
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -60,12 +64,14 @@ def exam_gpt_fwd_bwd(
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gather
-    model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
+    model = GeminiDDP(
+        model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
+    )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)

     rank = dist.get_rank()
-    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1)
+    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[rank])
@@ -106,4 +112,4 @@ def test_gpt(world_size):

 if __name__ == "__main__":
-    test_gpt(4)
+    test_gpt(1)
@@ -78,7 +78,11 @@ def exam_grad_clipping(placement_config, model_name: str):
         init_device = None

     model = GeminiDDP(
-        model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config
+        model,
+        chunk_config_dict=config_dict,
+        chunk_init_device=init_device,
+        pin_memory=True,
+        **placement_config,
     )

     optimizer = HybridAdam(model.parameters(), lr=1e-3)
......
@@ -44,7 +44,7 @@ BF16_IGNORED_KEYS = [
 def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
-    zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
+    zero_dict = model.state_dict(only_rank_0=False)
     torch_dict = torch_model.state_dict()

     for key, value in torch_dict.items():
......
@@ -27,7 +27,8 @@ def ignore_the_first_parameter(model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [True, False])
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_state_dict(placement_config, keep_gathered, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -42,7 +43,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered
-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
     model.train()

     zero_dict = model.state_dict(only_rank_0=False)
@@ -57,7 +58,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [True, False])
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -72,7 +74,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered
-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)

     torch_dict = torch_model.state_dict()
     model.load_state_dict(torch_dict, strict=False)
@@ -86,7 +88,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["gpt2", "bert"])
-def exam_state_dict_shard(placement_config, model_name: str):
+@parameterize("master_weights", [False, True])
+def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -95,7 +98,7 @@ def exam_state_dict_shard(placement_config, model_name: str):
     model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
-    model = GeminiDDP(model, config_dict, **placement_config)
+    model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
     model.train()

     zero_dict = model.state_dict(only_rank_0=False)
......
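
End to end, the tests above exercise the new mode roughly as follows; a hedged sketch, assuming a distributed environment already initialized via `colossalai.launch` and a CUDA device, with a toy module standing in for a real model:

```python
import torch
import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP, GeminiOptimizer

# wrap the module without fp32 master weights (reuse_fp16_chunk becomes False)
model = GeminiDDP(nn.Linear(32, 32), pin_memory=True, master_weights=False)
optim = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=1)

loss = model(torch.rand(4, 32, device="cuda")).sum()
optim.backward(loss)  # grads are reduced into dedicated grad chunks
optim.step()          # fp16 params updated in place; no fp16 <- fp32 copy
optim.zero_grad()
```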