Unverified Commit b36e01d5 authored by Sam Shleifer, committed by GitHub

[feat] add buffer_dtype kwarg for more control of batchnorm (#458)

parent 103d33c1
@@ -127,6 +127,8 @@ class FullyShardedDataParallel(nn.Module):
dtype for full parameters for computation. This defaults to
``torch.float32`` unless *``mixed_precision``* is set, in which case
it defaults to ``torch.float16``.
buffer_dtype (torch.dtype, Optional):
dtype for buffers for computation. This defaults to ``compute_dtype``.
move_grads_to_cpu (bool, Optional):
move gradient shard to CPU after reduction. This is useful when
combined with CPU-based optimizers. It defaults to the value of
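
For reference, a minimal usage sketch of the new kwarg. The import path, the toy model, and the dtype choices below are illustrative assumptions rather than something taken from this diff, and FSDP additionally expects torch.distributed to be initialized before wrapping:

import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

# Toy model whose BatchNorm running statistics are registered buffers.
model = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16))

# Compute in fp16, but keep buffers (e.g. BatchNorm running stats) in fp32.
fsdp_model = FSDP(model, mixed_precision=True, buffer_dtype=torch.float32)
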
@@ -150,6 +152,7 @@ class FullyShardedDataParallel(nn.Module):
flatten_parameters: bool = True,
cpu_offload: bool = False,
compute_dtype: Optional[torch.dtype] = None,
buffer_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
):
@@ -163,6 +166,7 @@ class FullyShardedDataParallel(nn.Module):
self.flatten_parameters = flatten_parameters
self.cpu_offload = cpu_offload
self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
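
The two assignments above are all there is to the default resolution. A standalone paraphrase (not the FSDP source itself) that can be run to sanity-check the fallbacks:

import torch

def resolve_dtypes(mixed_precision, compute_dtype=None, buffer_dtype=None):
    # Mirrors the constructor lines above: compute_dtype falls back to fp16
    # under mixed precision (else fp32); buffer_dtype falls back to compute_dtype.
    compute = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
    buffers = buffer_dtype or compute
    return compute, buffers

assert resolve_dtypes(True) == (torch.float16, torch.float16)
assert resolve_dtypes(True, buffer_dtype=torch.float32) == (torch.float16, torch.float32)
assert resolve_dtypes(False) == (torch.float32, torch.float32)
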
@@ -446,7 +450,7 @@ class FullyShardedDataParallel(nn.Module):
if self.mixed_precision:
# In case we are in mixed precision, restore buffers back to fp16.
self._all_buffers_to(dtype=self.compute_dtype)
self._all_buffers_to(dtype=self.buffer_dtype)
return state_dict
# TODO (Min): figuring out how to do typing for this overloaded function.
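
The body of _all_buffers_to is not shown in this diff; the following is only a plausible sketch of such a helper, and the floating-point guard is an assumption rather than something taken from the source:

import torch
import torch.nn as nn

def _all_buffers_to(module: nn.Module, device=None, dtype=None) -> None:
    # Move/cast every registered buffer in place. Skipping non-floating-point
    # buffers (e.g. integer counters) during the dtype cast is an assumption.
    for buf in module.buffers():
        if device is not None:
            buf.data = buf.data.to(device=device)
        if dtype is not None and torch.is_floating_point(buf):
            buf.data = buf.data.to(dtype=dtype)
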
@@ -619,9 +623,9 @@ class FullyShardedDataParallel(nn.Module):
self._setup_streams()
if self.cpu_offload: # Buffers stay on GPU, and don't get sharded
self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype)
self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype)
else:
self._all_buffers_to(dtype=self.compute_dtype)
self._all_buffers_to(dtype=self.buffer_dtype)
if self._is_root:
# Don't free the full params for the outer-most (root) instance,
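
Continuing the hypothetical usage sketch from the docstring section above: after lazy initialization has run (i.e. after the first forward pass), one would expect buffers to sit on the GPU in buffer_dtype even when cpu_offload keeps parameter shards on CPU. A hedged check, reusing the assumed names from that sketch:

# Assumes fsdp_model was built with cpu_offload=True, buffer_dtype=torch.float32
# and has already run one forward pass.
for name, buf in fsdp_model.named_buffers():
    assert buf.device.type == "cuda", f"{name} should stay on GPU under cpu_offload"
    if torch.is_floating_point(buf):
        assert buf.dtype == torch.float32, f"{name} should be cast to buffer_dtype"
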
@@ -471,6 +471,7 @@ class DeviceAndTypeCheckModule(Base):
expected_param_device: Optional[torch.device] = None,
expected_loss_dtype: Optional[torch.dtype] = None,
expected_loss_device: Optional[torch.device] = None,
expected_buffer_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.expected_input_dtype = expected_input_dtype
@@ -479,8 +480,10 @@ class DeviceAndTypeCheckModule(Base):
self.expected_param_device = expected_param_device
self.expected_loss_dtype = expected_loss_dtype
self.expected_loss_device = expected_loss_device
self.expected_buffer_dtype = expected_buffer_dtype
self.linear = nn.Linear(5, 5)
self.register_buffer("buffer", torch.rand((5,)))
def _check(
self,
@@ -498,8 +501,9 @@ class DeviceAndTypeCheckModule(Base):
param = self.linear.weight
self._check("param.dtype", param.dtype, self.expected_param_dtype)
self._check("param.device", param.device, self.expected_param_device)
loss = self.linear(x).sum()
self._check("buffer.dtype", self.buffer.dtype, self.expected_buffer_dtype) # type: ignore
x = x + self.buffer
loss = (self.linear(x) + self.buffer).sum()
self._check("loss.dtype", loss.dtype, self.expected_loss_dtype)
self._check("loss.device", loss.device, self.expected_loss_device)
@@ -110,6 +110,18 @@ class TestMixedPrecision(DistributedTest):
torch.float16, # expected_reduce_dtype
)
def test_mixed_precision_autocast_buffer_type_fp32(self):
"""If autocast enabled, loss should be fp32."""
self._spawn_test_case(
{"mixed_precision": True, "buffer_dtype": torch.float32},
True, # autocast enabled
torch.float16, # expected_input_dtype
torch.float16, # expected_param_dtype
torch.float32, # expected_loss_dtype
torch.float16, # expected_reduce_dtype
expected_buffer_type=torch.float32,
)
def test_mixed_precision_autocast_fp32_compute(self):
self._spawn_test_case(
{"mixed_precision": True, "compute_dtype": torch.float32},
@@ -118,6 +130,7 @@ class TestMixedPrecision(DistributedTest):
torch.float32, # expected_param_dtype
torch.float32, # expected_loss_dtype
torch.float32, # expected_reduce_dtype
expected_buffer_type=torch.float32,
)
def test_fp32_reduce_scatter(self):
@@ -128,6 +141,7 @@ class TestMixedPrecision(DistributedTest):
torch.float16, # expected_param_dtype
torch.float16, # expected_loss_dtype
torch.float32, # expected_reduce_dtype
expected_buffer_type=torch.float16,
)
def test_fp32_reduce_scatter_autocast(self):
@@ -140,18 +154,42 @@ class TestMixedPrecision(DistributedTest):
torch.float32, # expected_reduce_dtype
)
def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, world_size=2):
def _spawn_test_case(
self,
cfg,
autocast_enabled,
in_dtype,
p_dtype,
loss_dtype,
reduce_dtype,
expected_buffer_type=None,
world_size=2,
):
"""Call test_dtypes inside of torch.multiprocessing.spawn"""
fn = functools.partial(self._test_dtypes, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype)
fn = functools.partial(
self._test_dtypes,
cfg,
autocast_enabled,
in_dtype,
p_dtype,
loss_dtype,
reduce_dtype,
expected_buffer_type=expected_buffer_type,
)
spawn_and_init(fn, world_sizes=[world_size])
@staticmethod
def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group):
def _test_dtypes(
cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group, expected_buffer_type=None
):
# Patch torch.distributed.reduce_scatter to check the dtype of the reduction
orig_reduce_scatter = torch.distributed.reduce_scatter
model: nn.Module = DeviceAndTypeCheckModule(
expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype,
expected_input_dtype=in_dtype,
expected_param_dtype=p_dtype,
expected_loss_dtype=loss_dtype,
expected_buffer_dtype=expected_buffer_type,
)
def _reduce_scatter(output, input_list, **kwargs):
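    # The body is truncated in this hunk. A plausible sketch (not the verbatim
    # source): verify the dtype of every shard being reduced, then delegate to
    # the real collective captured in orig_reduce_scatter above.
    for t in input_list:
        assert t.dtype == reduce_dtype, f"reduce dtype {t.dtype} != {reduce_dtype}"
    return orig_reduce_scatter(output, input_list, **kwargs)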
@@ -265,7 +303,7 @@ class TestComparisonToPyTorchDDP(DistributedTest):
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config["mixed_precision"]:
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
@@ -399,7 +437,9 @@ class TestLocalStateDict(DistributedTest):
@classmethod
def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model = self.get_wrapped_model(group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model)
model = self.get_wrapped_model(
group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model, add_bn=False
) # Set add_bn=True here to show that BN doesn't get updated
state_1 = model.local_state_dict()
state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
assert len(state_1) > 0
@@ -639,7 +679,7 @@ class TestNoSync(DistributedTest):
def test_no_sync_before_first_forward(self):
group = DummyProcessGroup(rank=0, size=1)
model = self.get_wrapped_model(group, config={})
model = self.get_wrapped_model(group, config={}, add_bn=False)
batch = model.module.get_input(torch.device("cuda"))
with model.no_sync():
output = model(*batch)
@@ -651,7 +691,7 @@ class TestNoSync(DistributedTest):
@classmethod
def _test_transformer(self, rank, group, config):
model = self.get_wrapped_model(group, config=config)
model = self.get_wrapped_model(group, config=config, add_bn=False)
model.eval() # turn off dropout for the test
self._test_no_sync(model, batch_dim=1)
@@ -703,7 +743,7 @@ class TestNoSync(DistributedTest):
class TransformerWithSharedParams(nn.Module):
def __init__(self, group, *unused_args, d_vocab=23, d_model=16, **unused_kwargs):
def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
@@ -714,21 +754,26 @@ class TransformerWithSharedParams(nn.Module):
d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
)
self.output_proj = nn.Linear(d_model, d_vocab)
# share the embedding and output projection weights
self.output_proj.weight = self.embed_tokens.weight
self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))
self.bs = 2
self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
src = torch.arange(12, device=device).view(6, 2) # T x B
tgt = torch.arange(8, device=device).view(4, 2) # T x B
src = torch.arange(12, device=device).view(6, self.bs) # T x B
tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs) # T x B
return (src, tgt)
def forward(self, src_ids, tgt_ids):
src = self.embed_tokens(src_ids)
src = src + self.vocab_bias + self.long_buffer.type_as(src)
tgt = self.embed_tokens(tgt_ids)
tgt = self.bn(tgt)
x = self.transformer(src, tgt)
return self.output_proj(x)
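
The reason BatchNorm is wired into the test model is that its running statistics are registered buffers, exactly the tensors buffer_dtype controls. For reference, a stock BatchNorm1d exposes the following buffers (standard PyTorch behaviour, not specific to this diff):

import torch

bn = torch.nn.BatchNorm1d(2)
for name, buf in bn.named_buffers():
    print(name, buf.dtype)
# running_mean         torch.float32
# running_var          torch.float32
# num_batches_tracked  torch.int64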