Unverified Commit b36e01d5 authored by Sam Shleifer, committed by GitHub

[feat] add buffer_dtype kwarg for more control of batchnorm (#458)

parent 103d33c1
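A minimal usage sketch of the new kwarg (not part of the commit; it assumes torch.distributed has already been initialized and a CUDA device is available): keep BatchNorm running statistics in fp32 while mixed precision trains the parameters in fp16.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group(...) has already been called.
model = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16)).cuda()

# Without buffer_dtype, mixed_precision=True would also cast buffers to
# compute_dtype (fp16 by default); buffer_dtype=torch.float32 keeps buffers
# such as BatchNorm's running_mean / running_var in full precision.
fsdp_model = FSDP(model, mixed_precision=True, buffer_dtype=torch.float32)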
@@ -127,6 +127,8 @@ class FullyShardedDataParallel(nn.Module):
             dtype for full parameters for computation. This defaults to
             ``torch.float32`` unless *``mixed_precision``* is set, in which case
             it defaults to ``torch.float16``.
+        buffer_dtype (torch.dtype, Optional):
+            dtype for buffers for computation. This defaults to ``compute_dtype``.
         move_grads_to_cpu (bool, Optional):
             move gradient shard to CPU after reduction. This is useful when
             combined with CPU-based optimizers. It defaults to the value of
@@ -150,6 +152,7 @@ class FullyShardedDataParallel(nn.Module):
         flatten_parameters: bool = True,
         cpu_offload: bool = False,
         compute_dtype: Optional[torch.dtype] = None,
+        buffer_dtype: Optional[torch.dtype] = None,
         move_grads_to_cpu: Optional[bool] = None,
         bucket_cap_mb: int = 25,
     ):
@@ -163,6 +166,7 @@ class FullyShardedDataParallel(nn.Module):
         self.flatten_parameters = flatten_parameters
         self.cpu_offload = cpu_offload
         self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
+        self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
@@ -446,7 +450,7 @@ class FullyShardedDataParallel(nn.Module):
         if self.mixed_precision:
             # In case we are in mixed precision, restore buffers back to fp16.
-            self._all_buffers_to(dtype=self.compute_dtype)
+            self._all_buffers_to(dtype=self.buffer_dtype)
         return state_dict

     # TODO (Min): figuring out how to do typing for this overloaded function.
@@ -619,9 +623,9 @@ class FullyShardedDataParallel(nn.Module):
         self._setup_streams()

         if self.cpu_offload:  # Buffers stay on GPU, and don't get sharded
-            self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype)
+            self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype)
         else:
-            self._all_buffers_to(dtype=self.compute_dtype)
+            self._all_buffers_to(dtype=self.buffer_dtype)

         if self._is_root:
             # Don't free the full params for the outer-most (root) instance,
...
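The hunks above route buffer casting through self._all_buffers_to(...), whose body is not part of this diff. Purely to illustrate the intent (a sketch under that assumption, not fairscale's actual helper), a minimal buffer-casting routine could cast only floating-point buffers when a dtype is given, leaving integer buffers such as the test model's long_buffer untouched.

from typing import Optional

import torch
import torch.nn as nn


def all_buffers_to(module: nn.Module, device: Optional[torch.device] = None,
                   dtype: Optional[torch.dtype] = None) -> None:
    # Walk every submodule and re-register its buffers with the requested
    # device/dtype. Non-floating-point buffers keep their original dtype.
    for submodule in module.modules():
        for name, buf in submodule.named_buffers(recurse=False):
            if device is not None:
                buf = buf.to(device=device)
            if dtype is not None and buf.is_floating_point():
                buf = buf.to(dtype=dtype)
            setattr(submodule, name, buf)  # assigning to a registered buffer name updates _buffers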
@@ -471,6 +471,7 @@ class DeviceAndTypeCheckModule(Base):
         expected_param_device: Optional[torch.device] = None,
         expected_loss_dtype: Optional[torch.dtype] = None,
         expected_loss_device: Optional[torch.device] = None,
+        expected_buffer_dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.expected_input_dtype = expected_input_dtype
@@ -479,8 +480,10 @@ class DeviceAndTypeCheckModule(Base):
         self.expected_param_device = expected_param_device
         self.expected_loss_dtype = expected_loss_dtype
         self.expected_loss_device = expected_loss_device
+        self.expected_buffer_dtype = expected_buffer_dtype

         self.linear = nn.Linear(5, 5)
+        self.register_buffer("buffer", torch.rand((5,)))

     def _check(
         self,
@@ -498,8 +501,9 @@ class DeviceAndTypeCheckModule(Base):
         param = self.linear.weight
         self._check("param.dtype", param.dtype, self.expected_param_dtype)
         self._check("param.device", param.device, self.expected_param_device)
+        self._check("buffer.dtype", self.buffer.dtype, self.expected_buffer_dtype)  # type: ignore

-        loss = self.linear(x).sum()
+        x = x + self.buffer
+        loss = (self.linear(x) + self.buffer).sum()
         self._check("loss.dtype", loss.dtype, self.expected_loss_dtype)
         self._check("loss.device", loss.device, self.expected_loss_device)
...
@@ -110,6 +110,18 @@ class TestMixedPrecision(DistributedTest):
             torch.float16,  # expected_reduce_dtype
         )

+    def test_mixed_precision_autocast_buffer_type_fp32(self):
+        """With autocast enabled and buffer_dtype=torch.float32, buffers should stay fp32 (and the loss should be fp32)."""
+        self._spawn_test_case(
+            {"mixed_precision": True, "buffer_dtype": torch.float32},
+            True,  # autocast enabled
+            torch.float16,  # expected_input_dtype
+            torch.float16,  # expected_param_dtype
+            torch.float32,  # expected_loss_dtype
+            torch.float16,  # expected_reduce_dtype
+            expected_buffer_type=torch.float32,
+        )
+
     def test_mixed_precision_autocast_fp32_compute(self):
         self._spawn_test_case(
             {"mixed_precision": True, "compute_dtype": torch.float32},
@@ -118,6 +130,7 @@ class TestMixedPrecision(DistributedTest):
             torch.float32,  # expected_param_dtype
             torch.float32,  # expected_loss_dtype
             torch.float32,  # expected_reduce_dtype
+            expected_buffer_type=torch.float32,
         )

     def test_fp32_reduce_scatter(self):
@@ -128,6 +141,7 @@ class TestMixedPrecision(DistributedTest):
             torch.float16,  # expected_param_dtype
             torch.float16,  # expected_loss_dtype
             torch.float32,  # expected_reduce_dtype
+            expected_buffer_type=torch.float16,
         )

     def test_fp32_reduce_scatter_autocast(self):
@@ -140,18 +154,42 @@ class TestMixedPrecision(DistributedTest):
             torch.float32,  # expected_reduce_dtype
         )

-    def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, world_size=2):
+    def _spawn_test_case(
+        self,
+        cfg,
+        autocast_enabled,
+        in_dtype,
+        p_dtype,
+        loss_dtype,
+        reduce_dtype,
+        expected_buffer_type=None,
+        world_size=2,
+    ):
         """Call test_dtypes inside of torch.multiprocessing.spawn"""
-        fn = functools.partial(self._test_dtypes, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype)
+        fn = functools.partial(
+            self._test_dtypes,
+            cfg,
+            autocast_enabled,
+            in_dtype,
+            p_dtype,
+            loss_dtype,
+            reduce_dtype,
+            expected_buffer_type=expected_buffer_type,
+        )
         spawn_and_init(fn, world_sizes=[world_size])

     @staticmethod
-    def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group):
+    def _test_dtypes(
+        cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group, expected_buffer_type=None
+    ):
         # Patch torch.distributed.reduce_scatter to check the dtype of the reduction
         orig_reduce_scatter = torch.distributed.reduce_scatter

         model: nn.Module = DeviceAndTypeCheckModule(
-            expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype,
+            expected_input_dtype=in_dtype,
+            expected_param_dtype=p_dtype,
+            expected_loss_dtype=loss_dtype,
+            expected_buffer_dtype=expected_buffer_type,
         )

         def _reduce_scatter(output, input_list, **kwargs):
@@ -265,7 +303,7 @@ class TestComparisonToPyTorchDDP(DistributedTest):
     def _test_identical_outputs(
         cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
     ):
-        if config["mixed_precision"]:
+        if config.get("mixed_precision", False):
             autocast = True
             # Force the compute dtype to be torch.float32 so that we get
             # identical results as PyTorch DDP when using autocast. Note that
@@ -399,7 +437,9 @@ class TestLocalStateDict(DistributedTest):
     @classmethod
     def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
         """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
-        model = self.get_wrapped_model(group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model)
+        model = self.get_wrapped_model(
+            group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model, add_bn=False
+        )  # Set add_bn=True here to show that BN doesn't get updated
         state_1 = model.local_state_dict()
         state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
         assert len(state_1) > 0
@@ -639,7 +679,7 @@ class TestNoSync(DistributedTest):
     def test_no_sync_before_first_forward(self):
         group = DummyProcessGroup(rank=0, size=1)
-        model = self.get_wrapped_model(group, config={})
+        model = self.get_wrapped_model(group, config={}, add_bn=False)
         batch = model.module.get_input(torch.device("cuda"))
         with model.no_sync():
             output = model(*batch)
@@ -651,7 +691,7 @@ class TestNoSync(DistributedTest):
     @classmethod
     def _test_transformer(self, rank, group, config):
-        model = self.get_wrapped_model(group, config=config)
+        model = self.get_wrapped_model(group, config=config, add_bn=False)
         model.eval()  # turn off dropout for the test
         self._test_no_sync(model, batch_dim=1)
@@ -703,7 +743,7 @@ class TestNoSync(DistributedTest):

 class TransformerWithSharedParams(nn.Module):
-    def __init__(self, group, *unused_args, d_vocab=23, d_model=16, **unused_kwargs):
+    def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
         super().__init__()
         self.rank = group.rank()
         self.world_size = group.size()
@@ -714,21 +754,26 @@ class TransformerWithSharedParams(nn.Module):
             d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
         )
         self.output_proj = nn.Linear(d_model, d_vocab)

         # share the embedding and output projection weights
         self.output_proj.weight = self.embed_tokens.weight
         self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
         self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))

+        self.bs = 2
+        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
+
     def get_input(self, device):
         torch.manual_seed(1 + self.rank)  # keep everything deterministic
-        src = torch.arange(12, device=device).view(6, 2)  # T x B
-        tgt = torch.arange(8, device=device).view(4, 2)  # T x B
+        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
+        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
         return (src, tgt)

     def forward(self, src_ids, tgt_ids):
         src = self.embed_tokens(src_ids)
         src = src + self.vocab_bias + self.long_buffer.type_as(src)
         tgt = self.embed_tokens(tgt_ids)
+        tgt = self.bn(tgt)
         x = self.transformer(src, tgt)
         return self.output_proj(x)
...
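The test model above gains a BatchNorm1d precisely because batch norm keeps its running statistics as buffers rather than parameters, which is the state the new buffer_dtype kwarg controls independently of compute_dtype. A plain-PyTorch sketch of that distinction (illustration only, not part of the commit):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
print([n for n, _ in bn.named_parameters()])  # ['weight', 'bias']
print([n for n, _ in bn.named_buffers()])     # ['running_mean', 'running_var', 'num_batches_tracked']

# Casting the module the way fp16 training does hits floating-point buffers too:
bn.half()
print(bn.running_mean.dtype)         # torch.float16, what buffers get by default under mixed precision
print(bn.num_batches_tracked.dtype)  # torch.int64, integer buffers are left alone by .half()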