Commit ea6905a8 authored by ver217

free param.grad

parent 9506a8be
 import functools
-from ast import Try
+from asyncio.log import logger
 from collections import OrderedDict
 from typing import Any, Optional
...
@@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                            get_gradient_predivide_factor)
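The import hunk above adds free_storage to the names pulled in from ._zero3_utils. In ZeRO-3-style implementations (this module follows fairscale's FSDP closely), that helper releases a tensor's backing storage while keeping its shape, dtype, and device metadata, so the tensor object stays valid but occupies no memory. Below is a minimal sketch assuming that behavior; the actual helper in _zero3_utils.py may differ.

import torch

def free_storage(data: torch.Tensor) -> None:
    # Release the underlying storage while keeping the tensor's metadata intact.
    # Only safe when the tensor is the sole occupant of its storage, hence the
    # offset check.
    if data.storage().size() > 0:
        assert data.storage_offset() == 0
        data.storage().resize_(0)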
...
@@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
             else:
                 self._reduce_scatter_callback(param, new_grad)
             orig_grad_data.record_stream(self.comm_stream)
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        return empty_grad
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         if self.gradient_postdivide_factor > 1:
...
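The three added lines are what the commit title refers to: after handing the full gradient to the (possibly asynchronous) reduce-scatter, the post-backward hook returns a tensor with the same metadata as the incoming gradient but with its storage freed. Autograd then uses that placeholder in place of the dense gradient, so param.grad no longer pins a full-size buffer. The following is a self-contained toy illustration of the mechanism, assuming a gradient hook registered directly on the parameter; the real module wires its hook up differently and reduce-scatters the gradient instead of stashing it.

import torch

def free_storage(data: torch.Tensor) -> None:
    # Same behavior as the sketch above: drop the backing storage.
    if data.storage().size() > 0:
        assert data.storage_offset() == 0
        data.storage().resize_(0)

p = torch.nn.Parameter(torch.randn(4))
handled = {}

def post_backward_hook(grad: torch.Tensor) -> torch.Tensor:
    # Hand the real gradient off (the sharded model would reduce-scatter it here).
    handled['grad'] = grad.clone()
    # Return a storage-free placeholder; autograd stores it as p.grad.
    empty_grad = torch.empty_like(grad)
    free_storage(empty_grad)
    return empty_grad

p.register_hook(post_backward_hook)
(p * 2.0).sum().backward()

print(p.grad.shape)             # torch.Size([4]) -- metadata survives
print(p.grad.storage().size())  # 0 -- param.grad holds no memory
print(handled['grad'])          # the actual gradient, processed elsewhere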