Unverified Commit 813d42cc authored by Aryan, committed by GitHub

Group offloading improvements (#11094)

update
parent b4d7e9c6
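
This commit changes ModuleGroup.onload_ to move each parameter and buffer individually through its .data attribute instead of calling group_module.to(...), teaches offload_ to re-point standalone parameters and buffers at their entries in the pinned cpu_param_dict, and factors the pinning loop shared by the block-level and leaf-level paths into a _get_pinned_cpu_param_dict helper that also pins buffers. As a minimal, self-contained sketch of the pattern the diff relies on (pinned CPU staging tensors plus non_blocking copies on a side stream), consider the following; the module, names, and sizes are made up for illustration and this is not the diffusers implementation. The actual changes are in the diff below.

import torch
import torch.nn as nn

# Illustrative sketch only, not the diffusers implementation. Requires a CUDA device.
module = nn.Sequential(nn.Linear(64, 64), nn.BatchNorm1d(64)).eval()  # BatchNorm adds buffers
stream = torch.cuda.Stream()

# Pin every parameter and buffer once and remember the pinned CPU tensor per object.
pinned = {}
for t in list(module.parameters()) + list(module.buffers()):
    t.data = t.data.cpu().pin_memory()
    pinned[t] = t.data

def onload(device="cuda"):
    # Host->device copies from pinned memory can overlap compute when non_blocking=True.
    with torch.cuda.stream(stream):
        for t in pinned:
            t.data = t.data.to(device, non_blocking=True)

def offload():
    # Re-point each tensor at its pinned CPU copy instead of allocating a new one.
    torch.cuda.current_stream().synchronize()
    for t, cpu_t in pinned.items():
        t.data = cpu_t

onload()
torch.cuda.current_stream().wait_stream(stream)  # weights must have arrived before use
with torch.no_grad():
    out = module(torch.randn(8, 64, device="cuda"))
offload()
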
@@ -83,7 +83,10 @@ class ModuleGroup:
         with context:
             for group_module in self.modules:
-                group_module.to(self.onload_device, non_blocking=self.non_blocking)
+                for param in group_module.parameters():
+                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                for buffer in group_module.buffers():
+                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
             if self.parameters is not None:
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
@@ -98,6 +101,12 @@ class ModuleGroup:
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
+            if self.parameters is not None:
+                for param in self.parameters:
+                    param.data = self.cpu_param_dict[param]
+            if self.buffers is not None:
+                for buffer in self.buffers:
+                    buffer.data = self.cpu_param_dict[buffer]
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
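
Note on the offload path above: cpu_param_dict is keyed by the nn.Parameter and buffer objects themselves, and reassigning .data does not change a tensor's identity or hash, so the lookup still resolves to the original pinned CPU tensor after a group has been onloaded. A throwaway example of that invariant (names here are hypothetical, CUDA required):

import torch

p = torch.nn.Parameter(torch.randn(4))
pinned = p.data.cpu().pin_memory()
p.data = pinned
table = {p: pinned}                              # keyed by the parameter object

p.data = p.data.to("cuda", non_blocking=True)    # onload: the storage moves, the key does not
assert table[p] is pinned                        # the pinned CPU copy is still reachable
p.data = table[p]                                # offload: re-point at the pinned tensor
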
@@ -387,9 +396,7 @@ def _apply_group_offloading_block_level(
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
     cpu_param_dict = None
     if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+        cpu_param_dict = _get_pinned_cpu_param_dict(module)
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level(
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
     cpu_param_dict = None
     if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+        cpu_param_dict = _get_pinned_cpu_param_dict(module)
 
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
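
Both call sites, the block-level and the leaf-level path, now delegate to the shared helper defined in the next hunk. Compared with the inline loop they replace, the behavioural difference is that registered buffers are pinned and tracked as well, not just parameters, so modules whose state lives in buffers (for example normalization running statistics) are staged through pinned memory too. A small, hypothetical comparison of what the old dict missed:

import torch.nn as nn

m = nn.BatchNorm1d(8)  # 2 parameters (weight, bias) and 3 buffers (running stats + counter)

old_dict = {p: p.data for p in m.parameters()}                             # parameters only
new_dict = {t: t.data for t in list(m.parameters()) + list(m.buffers())}   # parameters + buffers

print(len(old_dict), len(new_dict))  # 2 5
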
@@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
+def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
+    cpu_param_dict = {}
+    for param in module.parameters():
+        param.data = param.data.cpu().pin_memory()
+        cpu_param_dict[param] = param.data
+    for buffer in module.buffers():
+        buffer.data = buffer.data.cpu().pin_memory()
+        cpu_param_dict[buffer] = buffer.data
+    return cpu_param_dict
+
+
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:
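
A quick way to exercise the new helper on a toy module; the function body is copied from the hunk above and repeated here only so the snippet runs without diffusers installed (a CUDA-enabled PyTorch build is needed for pin_memory):

from typing import Dict

import torch
import torch.nn as nn

def _get_pinned_cpu_param_dict(module: nn.Module) -> Dict[nn.Parameter, torch.Tensor]:
    cpu_param_dict = {}
    for param in module.parameters():
        param.data = param.data.cpu().pin_memory()
        cpu_param_dict[param] = param.data
    for buffer in module.buffers():
        buffer.data = buffer.data.cpu().pin_memory()
        cpu_param_dict[buffer] = buffer.data
    return cpu_param_dict

m = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16))
d = _get_pinned_cpu_param_dict(m)

# Every parameter and buffer is now backed by pinned host memory and tracked in the dict.
assert all(t.data.is_pinned() for t in list(m.parameters()) + list(m.buffers()))
assert len(d) == len(list(m.parameters())) + len(list(m.buffers()))
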