Unverified commit 0da60449 authored by Tim Moon, committed by GitHub

[PyTorch] Debug checkpointing with te.Sequential (#1629)



* Debug checkpointing with te.Sequential
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 20e95ba3
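
For context, a minimal sketch of the checkpoint round trip this commit debugs, boiled down from the new test in the diff below. It is illustrative only: the single-Linear model and import spellings are assumptions, not the exact reproducer.

    import io
    import torch
    import transformer_engine.pytorch.ops as te_ops

    # Build a model from fusible ops and save its state to an in-memory buffer
    model = te_ops.Sequential(te_ops.Linear(32, 32, device="cuda"))
    buffer = io.BytesIO()
    torch.save({"model": model.state_dict()}, buffer)

    # Restore into a freshly constructed model; loading quantizer state through
    # te_ops.Sequential is the code path this commit debugs
    model2 = te_ops.Sequential(te_ops.Linear(32, 32, device="cuda"))
    state = torch.load(io.BytesIO(buffer.getvalue()), weights_only=False)
    model2.load_state_dict(state["model"])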
@@ -5,6 +5,7 @@
 from __future__ import annotations
 from collections.abc import Iterable
+import io
 import math
 from typing import Optional
@@ -1885,3 +1886,118 @@ class TestFusedOps:
         torch.testing.assert_close(y2_test, y2_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
         torch.testing.assert_close(dw_test, w_ref.grad, **tols)
+
+
+class TestCheckpointing:
+    """Tests for checkpointing"""
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantized_weight", (False, True))
+    def test_linear(
+        self,
+        *,
+        pre_checkpoint_steps: int = 2,
+        post_checkpoint_steps: int = 2,
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_weight: bool,
+    ) -> None:
+        """Check checkpointing with linear op"""
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = list(in_shape)[:-1] + [in_features]
+        out_shape = in_shape[:-1] + [out_features]
+
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=out_shape)
+
+        # Construct model
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            model_save = te_ops.Sequential(
+                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
+            )
+        optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25)
+
+        # Warmup training steps
+        for _ in range(pre_checkpoint_steps):
+            x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
+            dy = torch.randn(out_shape, dtype=dtype, device=device)
+            optim_save.zero_grad()
+            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+                y = model_save(x)
+            y.backward(dy)
+            optim_save.step()
+
+        # Save checkpoint
+        byte_stream = io.BytesIO()
+        torch.save(
+            {"model": model_save.state_dict(), "optim": optim_save.state_dict()},
+            byte_stream,
+        )
+        checkpoint_bytes = byte_stream.getvalue()
+        del byte_stream
+
+        # Synthetic data for evaluation
+        xs_save = [
+            torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
+            for _ in range(post_checkpoint_steps)
+        ]
+        with torch.no_grad():
+            xs_load = [x.clone().requires_grad_() for x in xs_save]
+        dys = [
+            torch.randn(out_shape, dtype=dtype, device=device)
+            for _ in range(post_checkpoint_steps)
+        ]
+
+        # Training steps with original model
+        ys_save = []
+        for i in range(post_checkpoint_steps):
+            optim_save.zero_grad()
+            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+                y = model_save(xs_save[i])
+            y.backward(dys[i])
+            optim_save.step()
+            ys_save.append(y)
+
+        # Load checkpoint
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            model_load = te_ops.Sequential(
+                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
+            )
+        optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25)
+        state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
+        model_load.load_state_dict(state_dict["model"])
+        optim_load.load_state_dict(state_dict["optim"])
+
+        # Training steps with loaded model
+        ys_load = []
+        for i in range(post_checkpoint_steps):
+            optim_load.zero_grad()
+            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+                y = model_load(xs_load[i])
+            y.backward(dys[i])
+            optim_load.step()
+            ys_load.append(y)
+
+        # Check that original and loaded model match exactly
+        tols = {"rtol": 0, "atol": 0}
+        for param_load, param_save in zip(model_load.parameters(), model_save.parameters()):
+            torch.testing.assert_close(param_load, param_save, **tols)
+            torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
+        for y_load, y_save in zip(ys_load, ys_save):
+            torch.testing.assert_close(y_load, y_save, **tols)
+        for x_load, x_save in zip(xs_load, xs_save):
+            torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
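
Assuming the new class lands in the same test module as TestFusedOps above (tests/pytorch/test_fusible_ops.py is the likely path, though the diff does not name the file), the checkpointing tests can be run in isolation with:

    pytest -v tests/pytorch/test_fusible_ops.py -k TestCheckpointing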
@@ -19,6 +19,7 @@ from ..fp8 import (
     DelayedScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
+    fp8_autocast,
 )
 from ..tensor import Quantizer
@@ -508,7 +509,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
     def get_extra_state(self) -> torch.Tensor:
         """Serialize extra state

-        Contains metadata for FP8 casting.
+        Contains metadata for quantization recipe.

         """
@@ -540,23 +541,27 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             dst.copy_(src, non_blocking=True)
             return dst

-        # Store FP8 state
+        # Store quantizer state if needed
         state = {}
         for mode in ("forward", "backward"):

-            # Get state for a given FP8 tensor
-            if self.num_quantizers(mode) == 0:
+            # Skip if op has no quantizer state
+            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                 continue
-            fp8_meta = self.get_fp8_meta(mode)
+
+            # Quantizer state
+            fp8_meta = self._fp8_metas[mode]
             state[mode] = {}
             state[mode]["recipe"] = fp8_meta["recipe"]

-            # Store tensors
-            if "scaling_fwd" in fp8_meta:
-                state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
-                state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
-            if "scaling_bwd" in fp8_meta:
-                state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
-                state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)
+            # Copy tensors to CPU and store
+            if state[mode]["recipe"].delayed():
+                if mode == "forward":
+                    state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
+                    state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
+                if mode == "backward":
+                    state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
+                    state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

         # Store other picklable items
         extra = {}
@@ -595,37 +600,37 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                 dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
             dst.copy_(src, non_blocking=True)

-        # Load FP8 state
+        # Load quantizer state if needed
        for mode in ("forward", "backward"):

-            # Get state for a given FP8 tensor
+            # Skip if checkpoint has no quantizer state
             if mode not in state:
                 continue
-            if self.num_quantizers(mode) == 0:
-                continue
-            fp8_meta = self.get_fp8_meta(mode)
-            if fp8_meta is None:
-                continue

-            # Load extra state
+            # Get op's quantizer state, initializing if needed
+            if self._fp8_metas is None or self._fp8_metas[mode] is None:
+                with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
+                    self._reset_quantization_recipe_state()
+            fp8_meta = self._fp8_metas[mode]
+
+            # Load extra items
             fp8_meta["recipe"] = state[mode]["recipe"]
             fp8_meta.update(state[mode]["extra_fp8_variables"])
             if "amax_history_fwd" in state[mode]:
                 fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0)
             elif "amax_history_bwd" in state[mode]:
                 fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0)
             if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
                 del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

             # Load tensors
-            fp8_meta = self.get_fp8_meta(mode)
-            if "scaling_fwd" in fp8_meta:
-                fp8_meta_fwd = fp8_meta["scaling_fwd"]
-                copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
-                copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history)
-            if "scaling_bwd" in fp8_meta:
-                fp8_meta_bwd = fp8_meta["scaling_bwd"]
-                copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale)
-                copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history)
+            if state[mode]["recipe"].delayed():
+                if mode == "forward":
+                    copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
+                    copy_tensor(
+                        state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
+                    )
+                if mode == "backward":
+                    copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
+                    copy_tensor(
+                        state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
+                    )

         # Finish CPU-GPU memory transfers
         torch.cuda.synchronize()
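
Both hunks above sit inside PyTorch's extra-state hooks: whatever get_extra_state returns is stored in the module's state_dict under an "_extra_state" key, and set_extra_state receives it back during load_state_dict. A minimal sketch of that contract, using a toy module rather than Transformer Engine's BasicOperation (all names below are illustrative):

    import io
    import torch

    class OpWithExtraState(torch.nn.Module):
        """Toy module that round-trips non-parameter state through state_dict."""

        def __init__(self):
            super().__init__()
            self.scale_history = [1.0]  # stand-in for quantizer metadata

        def get_extra_state(self):
            # Any picklable object works; the real op packs its state differently
            return {"scale_history": self.scale_history}

        def set_extra_state(self, state):
            self.scale_history = state["scale_history"]

    m = OpWithExtraState()
    m.scale_history.append(0.5)
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)

    m2 = OpWithExtraState()
    m2.load_state_dict(torch.load(io.BytesIO(buf.getvalue())))
    assert m2.scale_history == [1.0, 0.5]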
@@ -347,6 +347,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         columnwise_scale_inv: torch.Tensor,
         fp8_dtype: TE_DType,
         dtype: torch.dtype,
+        shape: torch.shape,
     ) -> MXFP8Tensor:
         """Build MXFP8Tensor, for use in __reduce__
@@ -361,10 +362,11 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
             columnwise_data=columnwise_data,
             columnwise_scale_inv=columnwise_scale_inv,
             dtype=dtype,
+            shape=shape,
         )

     def __reduce_ex__(self, protocol: int) -> tuple:
-        """Custom pickling to remove references to FP8 metadata objects"""
+        """Custom pickling"""
         return (
             MXFP8Tensor._make_in_reduce_ex,
             (
@@ -374,6 +376,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
                 self._columnwise_scale_inv,
                 self._fp8_dtype,
                 self.dtype,
+                self.shape,
             ),
         )
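
The MXFP8Tensor hunks thread the tensor's shape through its pickle support: __reduce_ex__ must return a callable plus every argument needed to rebuild the object, so omitting a field silently yields a mis-reconstructed tensor on load. A generic sketch of the pattern with a toy class (not Transformer Engine's implementation):

    import pickle

    class PackedTensorLike:
        """Toy stand-in for a quantized tensor with out-of-band metadata."""

        def __init__(self, data: bytes, shape: tuple):
            self.data = data
            self.shape = shape

        @staticmethod
        def _rebuild(data: bytes, shape: tuple) -> "PackedTensorLike":
            # Analogous to MXFP8Tensor._make_in_reduce_ex: rebuild from raw fields
            return PackedTensorLike(data, shape)

        def __reduce_ex__(self, protocol: int) -> tuple:
            # Pickle records the callable and its arguments; the commit above adds
            # shape to the equivalent tuple so round trips preserve it
            return (PackedTensorLike._rebuild, (self.data, self.shape))

    t = PackedTensorLike(b"\x00" * 16, (4, 4))
    t2 = pickle.loads(pickle.dumps(t))
    assert t2.shape == (4, 4) and t2.data == t.data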