Commit 36f9a74a authored by ver217, committed by Frank Lee

fix sharded param hook and unit test

parent 001ca624
 import torch
-import torch.distributed as dist
 from colossalai.registry import OPHOOKS
 from . import BaseOpHook
@@ -21,29 +20,25 @@ class ShardParamHook(BaseOpHook):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.gather()
-            if dist.get_rank() == 0:
-                print(f'{param._name} pre fwd shape {param.ca_attr.payload("cpu").shape}')
+            param.data = param.ca_attr.payload()

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.shard()
-            if dist.get_rank() == 0:
-                print(f'{param._name} post fwd shape {param.ca_attr.payload("cpu").shape}')
+            param.data = param.ca_attr.payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.gather()
-            if dist.get_rank() == 0:
-                print(f'{param._name} pre bwd shape {param.ca_attr.payload("cpu").shape}')
+            param.data = param.ca_attr.payload()

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters():
             assert hasattr(param, 'ca_attr')
             param.ca_attr.shard()
-            if dist.get_rank() == 0:
-                print(f'{param._name} post bwd shape {param.ca_attr.payload("cpu").shape}')
+            param.data = param.ca_attr.payload()

     def pre_iter(self):
         pass
...
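The pattern the four hooks now share is: gather the full parameter and repoint `param.data` at it before compute, then re-shard and repoint `param.data` at the local shard afterwards, instead of printing debug shapes. Below is a minimal single-process sketch of that pattern; `_ToyShardedAttr` is a hypothetical stand-in for the real `ca_attr` object (whose `gather()`/`shard()` involve collective communication across ranks) and exists only to make the snippet runnable.

import torch
import torch.nn as nn


class _ToyShardedAttr:
    """Hypothetical stand-in for the real `ca_attr` (ShardedParam) object.

    The real gather()/shard() move data between a per-rank shard and the
    full tensor via collectives; this toy only switches which tensor
    payload() returns, to illustrate the hook pattern above.
    """

    def __init__(self, tensor: torch.Tensor, world_size: int = 2):
        self._full = tensor.clone()
        self._shard = torch.flatten(tensor).chunk(world_size)[0].clone()
        self.is_sharded = True

    def gather(self):
        self.is_sharded = False

    def shard(self):
        self.is_sharded = True

    def payload(self, target_device=None):
        data = self._shard if self.is_sharded else self._full
        return data if target_device is None else data.to(target_device)


layer = nn.Linear(4, 4)
for p in layer.parameters():
    p.ca_attr = _ToyShardedAttr(p.data)
    p.data = p.ca_attr.payload()              # parameters start out sharded

# pre_fwd_exec / pre_bwd_exec: gather, then point param.data at the full tensor
for p in layer.parameters():
    p.ca_attr.gather()
    p.data = p.ca_attr.payload()

out = layer(torch.randn(2, 4)).sum()

# post_fwd_exec / post_bwd_exec: shard again; param.data becomes the local shard
for p in layer.parameters():
    p.ca_attr.shard()
    p.data = p.ca_attr.payload()

print([tuple(p.shape) for p in layer.parameters()])   # flattened shard shapes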
@@ -6,8 +6,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
-                                        register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -109,6 +108,10 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
             p._sharded_grad.write_back()
+        # In case some post bwd hook is not fired
+        for p in self.module.parameters():
+            if not p.ca_attr.is_sharded:
+                p.ca_attr.shard()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
...
@@ -14,7 +14,6 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-from ..sharded_model._zero3_utils import free_storage
 from ._utils import has_inf_or_nan
@@ -63,8 +62,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if hasattr(p, 'ca_attr'):
                     assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
                     self.master_params[p] = p.ca_attr.payload(self.device)
-                    if dist.get_rank() == 0:
-                        print(f'load payload {p._name} {self.master_params[p].shape}')
                 else:
                     self.master_params[p] = p.data.to(device=self.device)
                 if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
@@ -91,23 +88,15 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 if hasattr(p, 'ca_attr'):
-                    if dist.get_rank() == 0:
-                        print(f'write {p._name} {p.shape} orig_shape {p.ca_attr._origin_shape} \
-                            payload shape {p.ca_attr._param_payload.shape} sharded {p.ca_attr.is_sharded}')
                     p.ca_attr.set_payload(p.data)
-                    # We cannot set p.data to None directly, so we free storage
-                    free_storage(p.data)
+                    p.data = p.ca_attr.payload()
         return ret

     def backward(self, loss: Tensor) -> None:
         loss = self.loss_scale * loss
         self.optim_state = OptimState.SCALED
         if self.model_is_sharded:
-            if dist.get_rank() == 0:
-                print('sharded model backward')
             self.model.backward(loss)
-            if dist.get_rank() == 0:
-                print('sharded model backward done')
         else:
             super().backward(loss)
...
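The behavioural change in `step()` is the last replaced pair above: after the wrapped optimizer has updated `p.data`, the result is written into the sharded payload and `p.data` is repointed at that payload instead of having its storage freed. A toy single-process sketch of that write-back, with hypothetical tensors standing in for the fp32 master shard and the payload:

import torch
import torch.nn as nn

# Toy illustration of the write-back at the end of step(), assuming the fp32
# master copy and the ShardedParam payload are both flattened local shards.
p = nn.Parameter(torch.randn(4, 4))
payload = torch.flatten(p.detach()).chunk(2)[0].clone()    # sharded payload (toy)
master = payload.clone().float()                           # fp32 master shard (toy)

p.data = master                            # hand the shard to the wrapped optimizer
p.data -= 0.1 * torch.ones_like(p.data)    # pretend the optimizer updated it

# set_payload(p.data): shapes must match exactly under the tightened assert
assert payload.shape == p.data.shape
payload.copy_(p.data)

# Old code freed p.data's storage here; now p.data simply points at the payload,
# so the parameter keeps a valid sharded tensor between iterations.
p.data = payload
print(p.shape)     # torch.Size([8]) -- the local shard, not the full 4x4 weight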
+from typing import Optional, Tuple, Union
+
+import numpy
 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
-from typing import Union, Tuple, Optional
-import numpy


 class ShardedParam(object):
@@ -28,6 +29,7 @@ class ShardedParam(object):
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
         self.is_sharded = False
+        self.device = device

         # Hijack the data payload of param
         if isinstance(other, torch.nn.Parameter):
@@ -50,17 +52,19 @@ class ShardedParam(object):
             self._payload_numel = None

-    def payload(self, target_device: torch.device):
+    def payload(self, target_device: Optional[torch.device] = None):
         r"""
         get the payload and move it to target device
         """
-        return self._param_payload.to(target_device)
+        if target_device is not None:
+            return self._param_payload.to(target_device)
+        return self._param_payload

     def set_payload(self, data: torch.Tensor):
         r"""
         set payload as data
         """
-        assert self._param_payload.numel() == data.numel()
+        assert self._param_payload.shape == data.shape
        self._param_payload.copy_(data)

     def shard(self):
...
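Taken in isolation, the `payload()`/`set_payload()` changes mean: the device argument is now optional (calling `payload()` with no argument returns the payload directly, without a copy), and write-backs must match the payload's shape, not just its element count. The hypothetical cut-down class below keeps only those two methods to show the new semantics; it is not the real `ShardedParam`, which also manages process groups and sharding.

from typing import Optional

import torch


class _PayloadOnly:
    """Hypothetical cut-down ShardedParam keeping only the two changed methods."""

    def __init__(self, tensor: torch.Tensor):
        self._param_payload = tensor

    def payload(self, target_device: Optional[torch.device] = None):
        # target_device is optional now: with no argument the payload is
        # returned directly (no copy); with a device it is moved as before.
        if target_device is not None:
            return self._param_payload.to(target_device)
        return self._param_payload

    def set_payload(self, data: torch.Tensor):
        # the assert was tightened from matching numel() to matching shape,
        # so a reshaped tensor of the same size is now rejected
        assert self._param_payload.shape == data.shape
        self._param_payload.copy_(data)


attr = _PayloadOnly(torch.zeros(2, 3))
same = attr.payload()                        # same tensor, no device transfer
cpu_copy = attr.payload(torch.device('cpu'))
attr.set_payload(torch.ones(2, 3))           # ok: shapes match
# attr.set_payload(torch.ones(6))            # would now fail: same numel, wrong shape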
@@ -3,37 +3,21 @@ from functools import partial
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint

 LOGGER = get_dist_logger()

-CONFIG = dict(
-    fp16=dict(
-        mode=None,
-    ),
-    zero=dict(
-        level=3,
-        verbose=False,
-        offload_optimizer_config=dict(
-            device='cpu',
-            pin_memory=True,
-            buffer_count=5,
-            fast_init=False
-        ),
-        offload_param_config=dict(
-            device='cpu',
-            pin_memory=True,
-            buffer_count=5,
-            buffer_size=1e8,
-            max_in_cpu=1e9
-        )
-    ),
-    parallel=dict(
-        pipeline=dict(size=1),
-        tensor=dict(size=1, mode=None)
-    )
-)
+CONFIG = dict(fp16=dict(mode=None,),
+              zero=dict(level=3,
+                        verbose=False,
+                        offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
+                        offload_param_config=dict(device='cpu',
+                                                  pin_memory=True,
+                                                  buffer_count=5,
+                                                  buffer_size=1e8,
+                                                  max_in_cpu=1e9)),
+              parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))


 def checkpoint_wrapper(module, enable=True):
@@ -43,6 +27,7 @@ def checkpoint_wrapper(module, enable=True):
 class Net(nn.Module):
     def __init__(self, checkpoint=False) -> None:
         super().__init__()
         self.fc1 = nn.Linear(5, 5)
@@ -50,13 +35,7 @@ class Net(nn.Module):
         self.fc3 = nn.Linear(5, 1)
         if checkpoint:
             self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [
-            self.fc1,
-            self.fc2,
-            self.fc1,
-            self.fc2,
-            self.fc3
-        ]
+        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

     def forward(self, x):
         for layer in self.layers:
@@ -111,3 +90,17 @@ def check_params_padding(model, zero_model, loose=False):
             zero_p = zero_p[:p.size(0)]
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose)
+
+
+def check_sharded_params_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_p = zero_p.ca_attr.payload(p.device)
+        chunks = torch.flatten(p).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        p = chunks[rank]
+        if zero_p.size(0) > p.size(0):
+            zero_p = zero_p[:p.size(0)]
+        assert p.dtype == zero_p.dtype
+        assert allclose(p, zero_p, loose=loose)
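A note on the slicing in the new `check_sharded_params_padding`: each rank flattens the reference parameter, takes its own chunk, and compares it against the zero model's local payload, trimming any padding first. The toy numbers below (hypothetical sizes, single process) show why the last ranks can end up with nothing to compare and why the payload may need trimming:

import torch

# Toy numbers for the chunking used by check_sharded_params_padding.
world_size = 4
full_param = torch.arange(7.)                        # 7 elements across 4 ranks
chunks = torch.flatten(full_param).chunk(world_size)
print([c.tolist() for c in chunks])   # [[0,1],[2,3],[4,5],[6]] -- last chunk is short
# (with only 6 elements, chunk() would return 3 chunks and rank 3 would hit
#  the `rank >= len(chunks)` branch and skip the comparison entirely)

# A payload padded up to the common shard size is trimmed before comparing,
# just like `zero_p = zero_p[:p.size(0)]` in the helper above.
padded_payload = torch.tensor([6.0, 0.0])            # rank 3's padded shard (toy)
local = chunks[3]
print(padded_payload[:local.size(0)].tolist())       # [6.0]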
@@ -16,19 +16,18 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from torch.optim import Adam

-from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_params_padding)
+from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_sharded_params_padding)


 def run_step(model, optimizer, x, enable_autocast=False):
     model.train()
+    optimizer.zero_grad()
     with torch.cuda.amp.autocast(enabled=enable_autocast):
         y = model(x)
         loss = y.sum()
     loss = loss.float()
     if isinstance(model, ShardedModelV2):
         optimizer.backward(loss)
+        for p in model.parameters():
+            assert p.ca_attr.is_sharded
     else:
         loss.backward()
     optimizer.step()
@@ -51,7 +50,7 @@ def run_dist(rank, world_size, port):
     run_step(model, optim, x, False)
     if dist.get_world_size() > 1:
         check_grads_padding(model, zero_model)
-        check_params_padding(model, zero_model)
+        check_sharded_params_padding(model, zero_model)
     else:
         check_grads(model, zero_model)
         check_params(model, zero_model)
...