Unverified Commit 8e757a45 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] RNG state support for model parallelism (#473)



* Add class for RNG state tracker.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix docs for checkpoint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4ae34765
......@@ -31,6 +31,9 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
:members: swap_key_value_dict
.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
:members: reset, get_states, set_states, add, fork
.. autoapifunction:: transformer_engine.pytorch.fp8_autocast
.. autoapifunction:: transformer_engine.pytorch.checkpoint
......
......@@ -4,7 +4,6 @@
import math
import os
import contextlib
from typing import List, Optional
import pytest
import copy
......@@ -12,8 +11,6 @@ import copy
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from transformer_engine.pytorch.utils import (
init_method_normal,
......@@ -25,9 +22,10 @@ from transformer_engine.pytorch import (
MultiheadAttention, RMSNorm, TransformerLayer
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
seed = 1234
rng_str = "rng_state"
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Record initial RNG state from script run.
......@@ -91,119 +89,14 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float)
raise AssertionError(msg)
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for 4+ GPU cases.
"""
if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device("cuda")
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
def reset_rng_states() -> None:
# revert back to initial RNG state.
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for bookkeeping and to ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception("seed {} already exists".format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception("cuda rng state {} already exists".format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=rng_str):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception("cuda rng state {} is not added".format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add(rng_str, seed)
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker():
......
......@@ -15,6 +15,7 @@ from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
onnx_cast_to_fp8,
......
......@@ -14,6 +14,10 @@ from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"tensor_model_parallel": False,
"partition_dim": -1,
......@@ -299,17 +303,14 @@ def checkpoint(
Parameters
----------
function: Callable
pytorch module used to run the forward and backward passes using
the specified :attr:`args` and :attr:`kwargs`.
distribute_saved_activations: bool
if set to `True`, the first tensor argument is distributed across the
specified tensor parallel group (`tp_group`) before saving it for the
backward pass.
get_cuda_rng_tracker: `Callable`
python callable which returns an instance of :class:`CudaRNGStatesTracker`.
The tracker's state is saved via :attr:`get_cuda_rng_tracker().get_states()` and
restored via :attr:`get_cuda_rng_tracker().set_states(state)`, ensuring that any
extra cuda rng state is reproduced across the two forward phases: original and
recompute.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
args : tuple
......@@ -328,6 +329,100 @@ def checkpoint(
)
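# --- Editorial sketch (not part of this commit): the tracker-getter contract ---
# A minimal illustration of passing a tracker getter to `checkpoint`, assuming
# the positional signature documented above (function,
# distribute_saved_activations, get_cuda_rng_tracker, tp_group, *args), as the
# test suite does with `te_checkpoint(block, False, get_dummy_cuda_rng_tracker,
# None, ...)`. `my_block` and `inp` are hypothetical placeholders.
#
#     _TRACKER = CudaRNGStatesTracker()
#     _TRACKER.add("model-parallel-rng", 1234)
#
#     def get_my_rng_tracker():
#         # checkpoint() only needs get_states()/set_states() on the returned
#         # object, so any CudaRNGStatesTracker instance works here.
#         return _TRACKER
#
#     out = checkpoint(my_block, False, get_my_rng_tracker, None, inp)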
class CudaRNGStatesTracker:
"""
For model parallelism, multiple RNG states need to simultaneously exist in order
to execute operations in or out of the model parallel region. This class keeps
track of the various RNG states and provides utility methods to maintain them and
execute parts of the model under a given RNG setting. Using the `add` method, a
cuda rng state is initialized based on the input `seed` and is assigned to `name`.
Later, by forking the rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for bookkeeping and to ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""
Set to the initial state (no tracker).
"""
self.states_ = {}
self.seeds_ = set()
def get_states(self) -> Dict[str, torch.Tensor]:
"""
Get rng states. Copy the dictionary so we have direct pointers
to the states, not just a pointer to the dictionary.
"""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states: Dict[str, torch.Tensor]) -> None:
"""
Set the rng states. For efficiency purposes, we do not
check the passed states for compatibility.
states: Dict[str, torch.Tensor]
A mapping from string names to RNG states.
"""
self.states_ = states
def add(self, name: str, seed: int) -> None:
"""
Adds a new RNG state.
name: str
string identifier for the RNG state.
seed: int
PyTorch seed for the RNG state.
"""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception(f"seed {seed} already exists")
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception(f"cuda rng state {name} already exists")
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextmanager
def fork(self, name: str = "model-parallel-rng") -> None:
"""
Fork the cuda rng state, perform operations, and exit with
the original state.
name: str
string identifier for the RNG state.
"""
# Check if we have added the state
if name not in self.states_:
raise Exception(f"cuda rng state {name} is not added")
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
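# --- Editorial sketch (not part of this commit): add/fork semantics ------------
# `add` seeds and stores a named CUDA RNG state without disturbing the default
# generator; `fork` temporarily installs the named state, runs the body, then
# saves the advanced state back and restores the original generator. The name
# and seed below are arbitrary.
#
#     tracker = CudaRNGStatesTracker()
#     tracker.add("model-parallel-rng", 2718)
#
#     with tracker.fork("model-parallel-rng"):
#         # Draws here consume the tracked state, not the default one.
#         mask = torch.rand(16, device="cuda") > 0.5
#
#     # Round-tripping the tracker state, e.g. around activation recompute:
#     snapshot = tracker.get_states()   # new dict, direct references to states
#     tracker.set_states(snapshot)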
def reduce_scatter_along_first_dim(
input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
......