Commit 272ebfb5 authored by Jiarui Fang's avatar Jiarui Fang Committed by Frank Lee
Browse files

[bug] shard param during initializing the ShardedModelV2 (#381)

parent 8c18eb09
......@@ -139,7 +139,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if self.convert_fp16:
param.data = param.data.to(torch.half)
if param.grad is not None:
param.grad = param.grad.to(torch.half).to(target_device)
param.grad = param.grad.to(torch.half)
# move torch parameters to the target device
param.data = param.data.to(target_device)
......
from ast import Try
import functools
from collections import OrderedDict
from typing import Any, Optional
......@@ -54,7 +55,7 @@ class ShardedModelV2(nn.Module):
# In case user didn't use ZeroInitContext
for param in self.module.parameters():
if not hasattr(param, 'col_attr'):
param.col_attr = ShardedParamV2(param, process_group)
param.col_attr = ShardedParamV2(param, process_group, rm_torch_payload=True)
if self.shard_param:
self.shard_strategy.shard([param.col_attr.data])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment