Commit 36f9a74a authored by ver217, committed by Frank Lee

fix sharded param hook and unit test

parent 001ca624
import torch
import torch.distributed as dist
from colossalai.registry import OPHOOKS
from . import BaseOpHook
......@@ -21,29 +20,25 @@ class ShardParamHook(BaseOpHook):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.gather()
if dist.get_rank() == 0:
print(f'{param._name} pre fwd shape {param.ca_attr.payload("cpu").shape}')
param.data = param.ca_attr.payload()
def post_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.shard()
if dist.get_rank() == 0:
print(f'{param._name} post fwd shape {param.ca_attr.payload("cpu").shape}')
param.data = param.ca_attr.payload()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.gather()
if dist.get_rank() == 0:
print(f'{param._name} pre bwd shape {param.ca_attr.payload("cpu").shape}')
param.data = param.ca_attr.payload()
def post_bwd_exec(self, module: torch.nn.Module, input):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.shard()
if dist.get_rank() == 0:
print(f'{param._name} post bwd shape {param.ca_attr.payload("cpu").shape}')
param.data = param.ca_attr.payload()
def pre_iter(self):
pass
......
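The hook above follows a single pattern around every module call: gather the full parameter before compute, point `param.data` at the gathered payload, then re-shard and point `param.data` back at the local chunk afterwards. Below is a minimal, self-contained sketch of that gather/shard lifecycle; `ToyShardedParam` is a hypothetical stand-in for `param.ca_attr` (single process, no collectives), not the ColossalAI implementation.

```python
import torch

class ToyShardedParam:
    """Hypothetical stand-in for ``param.ca_attr`` with the members the hook touches."""

    def __init__(self, tensor: torch.Tensor, world_size: int, rank: int):
        self.origin_shape = tensor.shape
        self._numel = tensor.numel()
        flat = tensor.detach().flatten()
        # Pad so the flat tensor splits evenly, then keep every rank's chunk locally
        # (a real implementation would all-gather the chunks instead).
        pad = (-flat.numel()) % world_size
        flat = torch.cat([flat, flat.new_zeros(pad)])
        self._shards = list(flat.chunk(world_size))
        self.rank = rank
        self.is_sharded = True
        self._payload = self._shards[rank].clone()

    def gather(self):
        full = torch.cat(self._shards)[:self._numel].view(self.origin_shape)
        self._payload = full
        self.is_sharded = False

    def shard(self):
        self._payload = self._shards[self.rank].clone()
        self.is_sharded = True

    def payload(self, device=None):
        return self._payload.to(device) if device is not None else self._payload


p = torch.nn.Parameter(torch.randn(5, 5))
p.ca_attr = ToyShardedParam(p.data, world_size=2, rank=0)

p.ca_attr.gather()
p.data = p.ca_attr.payload()   # full (5, 5) tensor visible around the module call
p.ca_attr.shard()
p.data = p.ca_attr.payload()   # only the local shard kept in between
```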
......@@ -6,8 +6,7 @@ import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
register_ophooks_recursively)
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
......@@ -109,6 +108,10 @@ class ShardedModelV2(nn.Module):
if not self._require_backward_grad_sync:
continue
p._sharded_grad.write_back()
# In case some post-backward hooks were not fired
for p in self.module.parameters():
if not p.ca_attr.is_sharded:
p.ca_attr.shard()
@torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
......
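The loop added at the end of the backward pass is a safety net: if a module is not reached in an iteration, its post-backward hook never fires and its parameters would stay unsharded. A hedged sketch of that sweep, treating `ca_attr` as a hypothetical object that exposes only `is_sharded` and `shard()`:

```python
import torch.nn as nn

def reshard_leftovers(module: nn.Module) -> int:
    """Re-shard any parameter still holding its full tensor after backward."""
    resharded = 0
    for p in module.parameters():
        attr = getattr(p, 'ca_attr', None)
        if attr is not None and not attr.is_sharded:
            attr.shard()          # only called for params whose hook did not fire
            resharded += 1
    return resharded
```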
......@@ -14,7 +14,6 @@ from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from ..sharded_model._zero3_utils import free_storage
from ._utils import has_inf_or_nan
......@@ -63,8 +62,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if hasattr(p, 'ca_attr'):
assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
self.master_params[p] = p.ca_attr.payload(self.device)
if dist.get_rank() == 0:
print(f'load payload {p._name} {self.master_params[p].shape}')
else:
self.master_params[p] = p.data.to(device=self.device)
if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
......@@ -91,23 +88,15 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for group in self.optim.param_groups:
for p in group['params']:
if hasattr(p, 'ca_attr'):
if dist.get_rank() == 0:
print(f'write {p._name} {p.shape} orig_shape {p.ca_attr._origin_shape} \
payload shape {p.ca_attr._param_payload.shape} sharded {p.ca_attr.is_sharded}')
p.ca_attr.set_payload(p.data)
# We cannot set p.data to None directly, so we free storage
free_storage(p.data)
p.data = p.ca_attr.payload()
return ret
def backward(self, loss: Tensor) -> None:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
if self.model_is_sharded:
if dist.get_rank() == 0:
print('sharded model backward')
self.model.backward(loss)
if dist.get_rank() == 0:
print('sharded model backward done')
else:
super().backward(loss)
......
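The optimizer-side change keeps the same write-back pattern after the wrapped optimizer has stepped on the master copies: copy the updated values into the sharded payload, release the full-size storage behind `p.data`, and repoint `p.data` at the shard. A sketch of that flow is below; the `free_storage` helper here is a plausible stand-in for the one imported from `_zero3_utils`, not necessarily its exact implementation.

```python
import torch

def free_storage(t: torch.Tensor) -> None:
    # p.data cannot be set to None, so release the memory behind it instead.
    if t.storage().size() > 0:
        t.storage().resize_(0)

def write_back(optim: torch.optim.Optimizer) -> None:
    for group in optim.param_groups:
        for p in group['params']:
            attr = getattr(p, 'ca_attr', None)
            if attr is None:
                continue
            attr.set_payload(p.data)   # copy updated values into the shard's payload
            free_storage(p.data)       # drop the full-size buffer...
            p.data = attr.payload()    # ...and point p.data back at the shard
```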
from typing import Optional, Tuple, Union
import numpy
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model._zero3_utils import get_shard
from typing import Union, Tuple, Optional
import numpy
class ShardedParam(object):
......@@ -28,6 +29,7 @@ class ShardedParam(object):
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
self.is_sharded = False
self.device = device
# Hijack the data payload of param
if isinstance(other, torch.nn.Parameter):
......@@ -50,17 +52,19 @@ class ShardedParam(object):
self._payload_numel = None
def payload(self, target_device: torch.device):
def payload(self, target_device: Optional[torch.device] = None):
r"""
Get the payload, moving it to ``target_device`` when one is given; otherwise return it as-is.
"""
if target_device is not None:
return self._param_payload.to(target_device)
return self._param_payload
def set_payload(self, data: torch.Tensor):
r"""
Copy ``data`` into the existing payload in place.
"""
assert self._param_payload.numel() == data.numel()
assert self._param_payload.shape == data.shape
self._param_payload.copy_(data)
def shard(self):
......
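The `ShardedParam` change makes `target_device` optional (returning the stored payload itself when no device is given) and tightens `set_payload` to check shape rather than element count. A minimal sketch of that revised API; `PayloadHolder` is an illustrative stand-in, not `ShardedParam` itself.

```python
from typing import Optional
import torch

class PayloadHolder:
    def __init__(self, data: torch.Tensor):
        self._param_payload = data

    def payload(self, target_device: Optional[torch.device] = None) -> torch.Tensor:
        if target_device is not None:
            return self._param_payload.to(target_device)
        return self._param_payload

    def set_payload(self, data: torch.Tensor) -> None:
        # Shape (not just numel) must match, so a sharded payload cannot be
        # silently overwritten by a full-size tensor or vice versa.
        assert self._param_payload.shape == data.shape
        self._param_payload.copy_(data)

h = PayloadHolder(torch.zeros(4))
same = h.payload()                       # no device: the stored tensor, no copy
moved = h.payload(torch.device('cpu'))   # explicit device: .to() may copy
h.set_payload(torch.ones(4))
```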
......@@ -3,37 +3,21 @@ from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
LOGGER = get_dist_logger()
CONFIG = dict(
fp16=dict(
mode=None,
),
zero=dict(
level=3,
CONFIG = dict(fp16=dict(mode=None,),
zero=dict(level=3,
verbose=False,
offload_optimizer_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
fast_init=False
),
offload_param_config=dict(
device='cpu',
offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
offload_param_config=dict(device='cpu',
pin_memory=True,
buffer_count=5,
buffer_size=1e8,
max_in_cpu=1e9
)
),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
)
max_in_cpu=1e9)),
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
def checkpoint_wrapper(module, enable=True):
......@@ -43,6 +27,7 @@ def checkpoint_wrapper(module, enable=True):
class Net(nn.Module):
def __init__(self, checkpoint=False) -> None:
super().__init__()
self.fc1 = nn.Linear(5, 5)
......@@ -50,13 +35,7 @@ class Net(nn.Module):
self.fc3 = nn.Linear(5, 1)
if checkpoint:
self.fc1 = checkpoint_wrapper(self.fc1)
self.layers = [
self.fc1,
self.fc2,
self.fc1,
self.fc2,
self.fc3
]
self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
def forward(self, x):
for layer in self.layers:
......@@ -111,3 +90,17 @@ def check_params_padding(model, zero_model, loose=False):
zero_p = zero_p[:p.size(0)]
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose)
def check_sharded_params_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_p = zero_p.ca_attr.payload(p.device)
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
p = chunks[rank]
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose)
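The new `check_sharded_params_padding` compares each full reference parameter against the local rank's shard instead of a re-assembled tensor: flatten, chunk by world size, pick this rank's chunk, and trim any right-padding the shard carries. A self-contained sketch of that comparison, with rank and world size passed in explicitly so no process group is needed:

```python
import torch

def local_reference_chunk(full_param: torch.Tensor, world_size: int, rank: int):
    chunks = torch.flatten(full_param).chunk(world_size)
    if rank >= len(chunks):        # tiny tensors can yield fewer chunks than ranks
        return None
    return chunks[rank]

def shards_match(full_param: torch.Tensor, local_shard: torch.Tensor,
                 world_size: int, rank: int, atol: float = 1e-5) -> bool:
    ref = local_reference_chunk(full_param, world_size, rank)
    if ref is None:
        return True
    if local_shard.size(0) > ref.size(0):   # the shard may carry right-padding
        local_shard = local_shard[:ref.size(0)]
    return torch.allclose(ref, local_shard, atol=atol)

full = torch.arange(10.0)
padded_shard = torch.cat([full[5:], torch.zeros(2)])   # rank 1's shard plus padding
assert shards_match(full, padded_shard, world_size=2, rank=1)
```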
......@@ -16,19 +16,18 @@ from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from torch.optim import Adam
from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_params_padding)
from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_sharded_params_padding)
def run_step(model, optimizer, x, enable_autocast=False):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
loss = loss.float()
if isinstance(model, ShardedModelV2):
optimizer.backward(loss)
for p in model.parameters():
assert p.ca_attr.is_sharded
else:
loss.backward()
optimizer.step()
......@@ -51,7 +50,7 @@ def run_dist(rank, world_size, port):
run_step(model, optim, x, False)
if dist.get_world_size() > 1:
check_grads_padding(model, zero_model)
check_params_padding(model, zero_model)
check_sharded_params_padding(model, zero_model)
else:
check_grads(model, zero_model)
check_params(model, zero_model)
......