Commit ea6905a8 authored by ver217
Browse files

free param.grad

parent 9506a8be
import functools
from ast import Try
from asyncio.log import logger
from collections import OrderedDict
from typing import Any, Optional
......@@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)
......@@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
else:
self._reduce_scatter_callback(param, new_grad)
orig_grad_data.record_stream(self.comm_stream)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
return empty_grad
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
if self.gradient_postdivide_factor > 1:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment