Unverified Commit bca0c49a authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[zero] use colo model data api in optimv2 (#511)

parent 9330be0f
...@@ -15,8 +15,8 @@ from torch import Tensor ...@@ -15,8 +15,8 @@ from torch import Tensor
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from ._utils import has_inf_or_nan from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
class OptimState(Enum): class OptimState(Enum):
...@@ -161,7 +161,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): ...@@ -161,7 +161,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16 # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
# TODO() optimize this line CPU (fp32) -> GPU (fp16) # TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.col_attr.sharded_data_tensor.copy_payload(p.data) colo_model_data_tensor_move(p, p.col_attr.sharded_data_tensor)
if not is_param_sharded: if not is_param_sharded:
# We gather full fp16 param here # We gather full fp16 param here
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment