Unverified Commit 599258f9 authored by Samyam Rajbhandari, committed by GitHub

ZeRO 3 Offload (#834)



* Squash stage3 v1 (#146)
Co-authored-by: Samyam <samyamr@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>

* Fix correctness bug (#147)

* formatting fix (#150)

* stage3 bugfix (API) update and simplified FP16 Z3 tests (#151)

* fp16 Z3 API update and bugfix

* revert debug change

* ZeRO-3 detach and race condition bugfixes (#149)

* trying out ZeRO-3 race condition fix

* CUDA sync instead of stream

* reduction stream sync

* remove commented code

* Fix optimizer state_dict KeyError (#148)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* fix for smaller SGS sizes, ensures each grad is backed by unique tensors (#152)

* Simplifying the logic for getting averaged gradients (#153)

* skip for now

* Z3 Docs redux (#154)

* removing some TODOs and commented code (#155)

* New Z3 defaults (#156)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* formatting

* megatron external params
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com>
Co-authored-by: eltonzheng <eltonz@microsoft.com>
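Note on usage (not part of this commit): ZeRO-3 Offload partitions parameters, gradients, and optimizer states across data-parallel ranks and can offload them to CPU memory. The sketch below shows how ZeRO stage 3 with CPU offload is typically enabled through the DeepSpeed config; the key names follow current DeepSpeed documentation (`offload_param` / `offload_optimizer`) and may differ from the flags available at the time of this commit, so treat the exact keys as an assumption.

# Hedged usage sketch: enable ZeRO-3 with CPU offload via a config dict.
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": True},
    # Let DeepSpeed build the optimizer so it can use its CPU Adam when offloading.
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu"},      # offload partitioned parameters to CPU
        "offload_optimizer": {"device": "cpu"},  # offload optimizer states to CPU
    },
}

model = torch.nn.Linear(10, 10)
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)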
parent ba33e86e
import os
import torch
import pytest
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

from common import distributed_test


def setup_serial_env():
    # Setup for a serial run
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29503'
    os.environ['LOCAL_RANK'] = '0'
    os.environ['RANK'] = '0'
    os.environ['WORLD_SIZE'] = '1'


def test_scattered_init_dist():
    setup_serial_env()
    assert not torch.distributed.is_initialized()
    with deepspeed.zero.Init():
        assert torch.distributed.is_initialized()


@distributed_test(world_size=2)
def test_scatter_gather():
    with deepspeed.zero.Init():
        l = torch.nn.Linear(6, 3)
    assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE
    assert l.weight.numel() == 1

    # Ensure there is no impact outside the context
    l2 = torch.nn.Linear(6, 3)
    assert not hasattr(l2.weight, 'ds_status')
    assert l2.weight.numel() == l2.in_features * l2.out_features

    with deepspeed.zero.GatheredParameters(l.weight):
        assert l.weight.ds_status == ZeroParamStatus.AVAILABLE
        assert l.weight.numel() == l.in_features * l.out_features


@distributed_test(world_size=2)
def test_gather_update():
    with deepspeed.zero.Init():
        l = torch.nn.Linear(4, 2)
    assert l.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE

    # Gather and make a change
    with deepspeed.zero.GatheredParameters(l.weight, modifier_rank=1):
        assert l.weight.ds_status == ZeroParamStatus.AVAILABLE
        if torch.distributed.get_rank() == 1:
            with torch.no_grad():
                l.weight.zero_()

    # should now be scattered again

    # Now gather again and ensure the change is global
    with deepspeed.zero.GatheredParameters(l.weight):
        # all ranks compare
        assert torch.equal(l.weight, torch.zeros_like(l.weight))
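
# Illustrative sketch (not one of the original tests): the user-side pattern for
# writing to a ZeRO-3 partitioned weight. Parameters created under zero.Init are
# partitioned across ranks, so in-place edits must happen inside GatheredParameters
# with a single modifier_rank; on exit the modified weight is re-partitioned so all
# ranks see the update. The function name below is illustrative only.
def _example_custom_weight_init(linear_module):
    # Gather the full weight on every rank; only rank 0 writes to it.
    with deepspeed.zero.GatheredParameters(linear_module.weight, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            with torch.no_grad():
                torch.nn.init.xavier_uniform_(linear_module.weight)
    # Leaving the context scatters the updated weight back to all ranks.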
@pytest.mark.skip('WIP')
def test_external_param():
    setup_serial_env()

    print()

    class ExtLinear(torch.nn.Module):
        def __init__(self, dim=10, copycat=None):
            super().__init__()
            self.dim = dim
            self.linear = torch.nn.Linear(dim, dim)

            if copycat is not None:
                with deepspeed.zero.GatheredParameters(self.linear.weight,
                                                       modifier_rank=0), \
                        torch.no_grad():
                    self.linear.weight.copy_(copycat.linear.weight)

            if hasattr(self.linear.weight, 'ds_id'):
                print('registering')
                super().ds_register_external_parameter('samyam', self.linear.weight)

        def forward(self, input):
            yamsam = self.linear(input)
            if hasattr(self.linear.weight, 'ds_status'):
                assert self.linear.weight.ds_status == ZeroParamStatus.AVAILABLE

            jeff = torch.nn.functional.linear(yamsam, self.linear.weight)
            return jeff

    l1_base = ExtLinear().half().cuda()
    l2_base = ExtLinear().half().cuda()

    input = torch.rand(10).half().cuda()

    l1_base_out = l1_base(input.clone().detach())
    l2_base_out = l2_base(input.clone().detach())

    with deepspeed.zero.Init():
        l1_test = ExtLinear(copycat=l1_base).cuda()
        #l2_test = ExtLinear(copycat=l2_base).cuda()

    assert l1_test.linear.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE

    # XXX l1 and l2 share their external parameter (l2.linear.weight)
    assert l1_test.linear.weight.ds_status == ZeroParamStatus.NOT_AVAILABLE

    l1_test_out = l1_test(input.clone().detach())

    #assert torch.allclose(l1_base_out, l1_test_out)

    #l2_test_out = l2_test(input.clone().detach())
    #assert torch.allclose(l2_base_out, l2_test_out)
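
# Illustrative sketch (not one of the original tests): the WIP test above registers
# an "external" parameter so ZeRO-3 knows to gather a weight that is accessed
# outside the module that owns it. DeepSpeed exposes this as
# deepspeed.zero.register_external_parameter(module, parameter); the helper below
# is a hedged example of that public API, not necessarily the mechanism used by
# this commit. The function and argument names are illustrative only.
def _example_register_external_param(consumer_module, shared_weight):
    # Tell ZeRO-3 that consumer_module's forward pass reads shared_weight,
    # even though shared_weight belongs to a different module.
    deepspeed.zero.register_external_parameter(consumer_module, shared_weight)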
def test_scatter_halftype():
    setup_serial_env()

    with deepspeed.zero.Init():
        l = torch.nn.Linear(10, 10)
        assert l.weight.ds_tensor.dtype == torch.float16

        y = torch.LongTensor([3, 3])
        assert y.dtype == torch.long