Unverified Commit f20d3ddb authored by Tim Moon, committed by GitHub

[PyTorch] Debug checkpointing with operation-based API (#1063)



* Debug checkpointing with operation-based API
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Store checkpoint FP8 state on CPU
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix bug where linear op was saving params multiple times
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 50b22da8
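
For context, the sketch below illustrates the checkpointing flow this change targets with the operation-based API: run a step under fp8_autocast so the FP8 scaling metadata is populated, save the state_dict (which now carries that metadata as CPU extra state), and load it into a freshly constructed model. It is a minimal, hypothetical example that only uses APIs appearing in the diff below; the shapes and recipe settings are arbitrary.

import io
import torch
import transformer_engine.common
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Small operation-based model and an FP8 recipe (settings are illustrative)
recipe = transformer_engine.common.recipe.DelayedScaling(
    margin=2,
    fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
model = te_ops.Sequential(te_ops.Linear(16, 16, device="cuda"))

# One forward/backward step under fp8_autocast populates the FP8 metadata
x = torch.rand(16, 16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = model(x)
y.backward(torch.rand_like(y))

# Round-trip the checkpoint through an in-memory buffer
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
new_model = te_ops.Sequential(te_ops.Linear(16, 16, device="cuda"))
new_model.load_state_dict(torch.load(io.BytesIO(buffer.getvalue())))
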
@@ -17,7 +17,9 @@ from typing import Iterable, Union

import pytest
import torch
import transformer_engine.common
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
@@ -287,3 +289,186 @@ def test_fp8_model_checkpoint(
    torch.testing.assert_close(
        model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()
    )
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("save_fp8_model", (False, True))
@pytest.mark.parametrize("load_fp8_model", (False, True))
def test_sequential_model(
*,
in_shape: Iterable[int] = (16, 16),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
save_steps: int = 2,
load_steps: int = 2,
fp8: bool,
save_fp8_model: bool,
load_fp8_model: bool,
) -> None:
# Skip invalid configurations
if fp8 or save_fp8_model or load_fp8_model:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")
# FP8 recipe
margin = 2
fp8_format = transformer_engine.common.recipe.Format.E4M3
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
fp8_format=fp8_format,
amax_history_len=8,
amax_compute_algo="max",
)
# Construct model to save to checkpoint
with te.fp8_model_init(enabled=save_fp8_model):
model = te_ops.Sequential(
te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype),
)
with torch.no_grad():
torch.rand(model[0].weight.size(), out=model[0].weight)
torch.rand(model[0].bias.size(), out=model[0].bias)
# Synthetic data
xs_ref = [
torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps)
]
dys_ref = [
torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps)
]
def train_step(
model: te_ops.Sequential,
x: torch.Tensor,
dy: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Helper function to perform training step"""
x = x.detach().clone().requires_grad_()
dy = dy.detach().clone()
with te.fp8_autocast(enabled=fp8, fp8_recipe=recipe):
y = model(x)
y.backward(dy)
with torch.no_grad():
for param in model.parameters():
param += 0.125
return (
y.detach().clone(),
x.grad.detach().clone(),
model[0].weight.detach().float().clone(),
)
# Initial training steps with saved model
ys_ref = []
dxs_ref = []
ws_ref = []
for step in range(save_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
ys_ref.append(y)
dxs_ref.append(dx)
ws_ref.append(w)
# Keep track of FP8 metadata if needed
fp8_meta_ref = dict(input={}, param={}, grad_output={})
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
m_ref["amax"] = m_model.amax_history.detach().clone()
m_ref["scale"] = m_model.scale.detach().clone()
m_ref["scale_inv"] = m_model.scale_inv.detach().clone()
del m_model, m_ref
# Save checkpoint
byte_stream = io.BytesIO()
torch.save(model.state_dict(), byte_stream)
model_bytes = byte_stream.getvalue()
del byte_stream
# More training steps with saved model
for step in range(save_steps, save_steps + load_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
ys_ref.append(y)
dxs_ref.append(dx)
ws_ref.append(w)
# Disturb and destroy model
with torch.no_grad():
for param in model.parameters():
param.zero_()
model[0].basic_ops[0]._fp8_metas = None
del model
# Construct new model to load from checkpoint
with te.fp8_model_init(enabled=load_fp8_model):
model = te_ops.Sequential(
te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype),
)
# Tolerances for numerical checks
tols = {}
if fp8 or save_fp8_model or load_fp8_model:
tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625
exact_tols = dict(rtol=0, atol=0)
# Training steps with dummy data
for step in range(save_steps):
y, dx, w = train_step(
model,
torch.zeros_like(xs_ref[step]),
torch.zeros_like(dys_ref[step]),
)
# Make sure results don't match saved model
with pytest.raises(AssertionError):
torch.testing.assert_close(y, ys_ref[step], **tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(dx, dxs_ref[step], **tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(w, ws_ref[step], **tols)
# Make sure new model's FP8 metadata doesn't match saved model
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)
# Load checkpoint
model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
del model_bytes
# Check that new model's FP8 metadata matches saved model
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols)
torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols)
torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)
# More training steps with loaded model
for step in range(save_steps, save_steps + load_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
torch.testing.assert_close(y, ys_ref[step], **tols)
torch.testing.assert_close(dx, dxs_ref[step], **tols)
torch.testing.assert_close(w, ws_ref[step], **tols)
@@ -133,6 +133,38 @@ class Linear(FusedOperation):

        # Initialize base class
        super().__init__(ops)

        # Register parameters
        self.register_parameter("weight", self.basic_ops[0].weight)
        self.register_parameter("bias", self.basic_ops[1].bias if bias else None)
        self._has_bias: bool = bias

    @property
    def weight(self) -> torch.nn.Parameter:
        """Weight tensor

        Parameter is owned by `BasicLinear` operation.

        """
        return self.basic_ops[0].weight

    @weight.setter
    def weight(self, value: Optional[torch.nn.Parameter]) -> None:
        self.basic_ops[0].weight = value

    @property
    def bias(self) -> Optional[torch.nn.Parameter]:
        """Bias tensor

        Parameter is owned by `Bias` operation.

        """
        if self._has_bias:
            return self.basic_ops[1].bias
        return None

    @bias.setter
    def bias(self, value: Optional[torch.nn.Parameter]) -> None:
        if self._has_bias:
            self.basic_ops[1].bias = value
        elif value is not None:
            raise ValueError(
                "Attempted to set bias parameter in Linear operation "
                "that does not have bias enabled"
            )
@@ -8,6 +8,7 @@ from __future__ import annotations

import abc
from collections.abc import Iterable
import dataclasses
import pickle
from typing import Any, Optional

import torch

@@ -504,6 +505,161 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
            basic_op_kwargs=[kwargs],
        )

    def get_extra_state(self) -> Optional[torch.Tensor]:
        """Serialize extra state

        Contains metadata for FP8 casting.

        """

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     It seems that ONNX export experiences issues with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        # Return immediately if op has no FP8 state
        has_fp8_state = any(
            self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
        )
        if not has_fp8_state:
            return None

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store FP8 state
        state = {}
        for mode in ("input", "param", "grad_output"):

            # Get state for a given FP8 tensor
            if self.num_fp8_scales(mode) == 0:
                state[mode] = None
                continue
            fp8_meta = self.get_fp8_meta(mode)
            if fp8_meta is None:
                continue
            state[mode] = {}

            # Store tensors
            if "scaling_fwd" in fp8_meta:
                state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
                state[mode]["scale_inv_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale_inv)
                state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
            if "scaling_bwd" in fp8_meta:
                state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
                state[mode]["scale_inv_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale_inv)
                state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

            # Store other picklable items
            extra = {}
            for key, val in fp8_meta.items():
                if key == "buffer_index_and_autocast_key":
                    continue
                if not isinstance(val, (bool, int, float, str, tuple, list)):
                    continue
                extra[key] = val
            state[mode]["extra_fp8_variables"] = extra

        # Serialize state into byte tensor
        torch.cuda.synchronize()
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
        """Load extra state"""
        if state is None:
            return

        # Deserialize state from byte tensor
        state = pickle.loads(state.detach().numpy(force=True).tobytes())
        if state is None:
            return

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            if src.size() != dst.size():
                dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
            dst.copy_(src, non_blocking=True)

        # Load FP8 state
        for mode in ("input", "param", "grad_output"):

            # Get state for a given FP8 tensor
            if mode not in state:
                continue
            if self.num_fp8_scales(mode) == 0:
                continue
            fp8_meta = self.get_fp8_meta(mode)
            if fp8_meta is None:
                continue

            # Load extra state
            fp8_meta.update(state[mode]["extra_fp8_variables"])
            if "amax_history_fwd" in state[mode]:
                fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0)
            elif "amax_history_bwd" in state[mode]:
                fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0)
            if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
                del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

            # Load tensors
            fp8_meta = self.get_fp8_meta(mode)
            if "scaling_fwd" in fp8_meta:
                fp8_meta_fwd = fp8_meta["scaling_fwd"]
                copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
                copy_tensor(state[mode]["scale_inv_fwd"], fp8_meta_fwd.scale_inv)
                copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history)
            if "scaling_bwd" in fp8_meta:
                fp8_meta_bwd = fp8_meta["scaling_bwd"]
                copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale)
                copy_tensor(state[mode]["scale_inv_bwd"], fp8_meta_bwd.scale_inv)
                copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history)

        # Finish CPU-GPU memory transfers
        torch.cuda.synchronize()

    def _load_from_state_dict(self, *args, **kwargs) -> None:
        """Load state"""

        # In the base PyTorch module class, the extra state is loaded
        # _after_ the parameters. However, copying values into FP8
        # parameters requires an FP8 cast, which uses a scaling factor
        # from the operation's FP8 metadata. The FP8 metadata is
        # included in the operation's extra state, so we need to
        # manually load the extra state before loading parameters.

        state_dict, prefix = args[0], args[1]
        extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(*args, **kwargs)

class FusedOperation(FusibleOperation):
    """Compound tensor operation supported by the operation fuser
...
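
For reference, here is a standalone sketch (not part of the diff) of the serialization pattern used by get_extra_state and set_extra_state above: small GPU tensors are copied to the CPU with non-blocking transfers, pickled, and wrapped in a uint8 tensor so that the extra state handed to PyTorch is always a CPU tensor. The helper names are hypothetical; only the pattern mirrors the code above.

import pickle
import torch

def pack_extra_state(tensors: dict[str, torch.Tensor]) -> torch.Tensor:
    """Copy tensors to CPU and serialize them into a uint8 byte tensor."""
    cpu_state = {}
    for name, src in tensors.items():
        dst = torch.empty_like(src, device="cpu")
        dst.copy_(src, non_blocking=True)  # asynchronous w.r.t. host
        cpu_state[name] = dst
    torch.cuda.synchronize()  # make sure all copies have finished
    return torch.frombuffer(bytearray(pickle.dumps(cpu_state)), dtype=torch.uint8)

def unpack_extra_state(state: torch.Tensor) -> dict[str, torch.Tensor]:
    """Deserialize the byte tensor produced by pack_extra_state."""
    return pickle.loads(state.detach().numpy(force=True).tobytes())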