"...git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "80e8c1d3d29e7c1e8f3bd9eb02b3c99e874bdb9f"
Commit 70814dc2 authored by ver217, committed by Frank Lee

fix master params dtype

parent 795210dd
@@ -26,7 +26,7 @@ class ShardedAdam(ColossalaiOptimizer):
 
     def __init__(self,
                  adam_optim: Optimizer,
-                 sharded_model: nn.Module,
+                 sharded_model: Union[nn.Module, ShardedModelV2],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -61,9 +61,11 @@ class ShardedAdam(ColossalaiOptimizer):
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device).to(torch.float)
+                    self.master_params[p] = p.ca_attr.payload(self.device)
                 else:
-                    self.master_params[p] = p.data.to(torch.float)
+                    self.master_params[p] = p.data.to(device=self.device)
+                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
+                    self.master_params[p] = self.master_params[p].to(torch.float)
 
     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -85,8 +87,9 @@ class ShardedAdam(ColossalaiOptimizer):
         # Write master param to payload and set p.data to None
         for group in self.optim.param_groups:
             for p in group['params']:
-                # TODO: update payload
-                p.data = None
+                if hasattr(p, 'ca_attr'):
+                    # TODO: update payload
+                    p.data = None
         return ret
 
     def backward(self, loss: Tensor) -> None:
@@ -129,10 +132,7 @@ class ShardedAdam(ColossalaiOptimizer):
         # all-reduce over model parallel group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)
 
-        if self._found_overflow.item() > 0:
-            return True
-        else:
-            return False
+        return self._found_overflow.item() > 0
 
     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
...
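The core of the commit is the master-parameter initialisation in __init__: previously every parameter was unconditionally cast to torch.float, which is incorrect for non-floating-point tensors, whereas the new code first moves the data to the target device and only up-casts floating-point tensors that are not already fp32. Below is a minimal standalone sketch of that dtype rule; the helper name make_master_param is hypothetical and used for illustration only, not part of the repository.

import torch

def make_master_param(p: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Hypothetical helper mirroring the updated __init__ logic:
    # move to the target device, then up-cast to fp32 only when the tensor
    # is floating point and not already fp32; integer tensors stay untouched.
    master = p.data.to(device=device)
    if torch.is_floating_point(master) and master.dtype != torch.float:
        master = master.to(torch.float)
    return master

# An fp16 parameter becomes an fp32 master copy; an int64 tensor keeps its dtype.
print(make_master_param(torch.zeros(4, dtype=torch.half), torch.device('cpu')).dtype)  # torch.float32
print(make_master_param(torch.zeros(4, dtype=torch.long), torch.device('cpu')).dtype)  # torch.int64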