Unverified Commit 73f8d90f authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] cuda graph support (#575)



* FP8 cuda graphs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>

* Fix numerics
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* exclude torch compile from numerics tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More numerics fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm fusion from unfused path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
parent 1b20f2d6
......@@ -157,7 +157,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
......@@ -238,13 +238,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _stop_comm, 0));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
......@@ -350,11 +347,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
......@@ -469,13 +467,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
_ub_comm->sms = ori_sms;
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main);
......@@ -506,7 +504,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
}
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(),
......@@ -805,14 +803,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
if (_aggregate2) {
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
......@@ -877,21 +876,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
}
at::cuda::setCurrentCUDAStream(stream_main);
int last_compute_stream_id =
(num_steps + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
} else {
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
......@@ -936,16 +923,19 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
}
at::cuda::setCurrentCUDAStream(stream_main);
int last_compute_stream_id = (_tp_size + _stream_compute.size() - 1) % _stream_compute.size();
}
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return D;
} // split_overlap_ag
......
......@@ -43,6 +43,7 @@
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <transformer_engine/cast_transpose_noop.h>
namespace transformer_engine {
......
......@@ -223,6 +223,17 @@ void fused_cast_transpose(at::Tensor input,
);
void fused_cast_transpose_noop(at::Tensor input,
at::Tensor noop,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
);
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
......@@ -263,6 +274,17 @@ at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype
);
void fp8_transpose_noalloc(at::Tensor input,
at::Tensor output,
transformer_engine::DType otype
);
void fp8_transpose_noalloc_noop(at::Tensor input,
at::Tensor output,
at::Tensor noop,
transformer_engine::DType otype
);
/***************************************************************************************************
* Activations
**************************************************************************************************/
......@@ -559,16 +581,13 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
* FP8 recipe
**************************************************************************************************/
void fused_amax_and_scale_update(const at::Tensor &amax_history,
const at::Tensor &scale,
const at::Tensor &scale_inv,
const at::Tensor &scale_inv_mask,
at::Tensor updated_amax_history,
at::Tensor updated_scale,
at::Tensor updated_scale_inv,
const std::string& amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin);
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
std::vector<at::Tensor> scale_invs,
const std::string &amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin);
/***************************************************************************************************
* Rotary positional embedding
......
......@@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop,
"Fused Cast + Transpose with noop option");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad,
......@@ -67,6 +69,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O");
m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop,
"Transpose with FP8 I/O with noop option.");
m.def("gelu", &gelu, "GeLU with FP8 output");
m.def("relu", &relu, "ReLU with FP8 output");
m.def("geglu", &geglu, "GeGLU with FP8 output");
......@@ -82,9 +87,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
m.def("fused_amax_and_scale_update",
&fused_amax_and_scale_update,
"Update amax history and FP8 scale");
m.def("fused_amax_and_scale_update_after_reduction",
&fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction");
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD");
......
......@@ -11,24 +11,50 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
void fused_amax_and_scale_update(const at::Tensor &amax_history,
const at::Tensor &scale,
const at::Tensor &scale_inv,
const at::Tensor &scale_inv_mask,
at::Tensor updated_amax_history,
at::Tensor updated_scale,
at::Tensor updated_scale_inv,
const std::string& amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin) {
nvte_delayed_scaling_recipe_amax_and_scale_update(
makeTransformerEngineTensor(amax_history).data(),
makeTransformerEngineTensor(scale).data(),
makeTransformerEngineTensor(scale_inv).data(),
makeTransformerEngineTensor(scale_inv_mask).data(),
makeTransformerEngineTensor(updated_amax_history).data(),
makeTransformerEngineTensor(updated_scale).data(),
makeTransformerEngineTensor(updated_scale_inv).data(),
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
std::vector<at::Tensor> scale_invs,
const std::string &amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin) {
using namespace transformer_engine;
size_t num_tensors = amax_histories.size();
std::vector<Tensor> t_amax_histories(num_tensors);
std::vector<Tensor> t_scales(num_tensors);
std::vector<Tensor> t_scale_invs(num_tensors);
std::vector<NVTETensor> te_amax_histories(num_tensors);
std::vector<NVTETensor> te_scales(num_tensors);
std::vector<NVTETensor> te_scale_invs(num_tensors);
for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
auto amax_sizes = amax_histories[i].sizes().vec();
std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()};
t_amax_histories[i].data.shape = amax_shape;
t_amax_histories[i].data.dtype = DType::kFloat32;
t_scales[i].data.dptr = scales[i].data_ptr();
auto scale_sizes = scales[i].sizes().vec();
std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()};
t_scales[i].data.shape = scale_shape;
t_scales[i].data.dtype = DType::kFloat32;
t_scale_invs[i].data.dptr = scale_invs[i].data_ptr();
auto scale_inv_sizes = scale_invs[i].sizes().vec();
std::vector<size_t> scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()};
t_scale_invs[i].data.shape = scale_inv_shape;
t_scale_invs[i].data.dtype = DType::kFloat32;
te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
te_scale_invs[i] = reinterpret_cast<NVTETensor>(&t_scale_invs[i]);
}
nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
makeTransformerEngineTensor(amax_reduction_buffer).data(),
te_amax_histories,
te_scales,
te_scale_invs,
amax_compute_algo.c_str(),
static_cast<NVTEDType>(fp8_dtype),
margin,
......
......@@ -32,6 +32,35 @@ void fused_cast_transpose(at::Tensor input,
}
void fused_cast_transpose_noop(at::Tensor input,
at::Tensor noop,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input);
auto noop_cu = makeTransformerEngineTensor(noop);
auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(),
output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
......@@ -319,3 +348,39 @@ at::Tensor fp8_transpose(at::Tensor input,
return output;
}
void fp8_transpose_noalloc(at::Tensor input,
at::Tensor output,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
}
void fp8_transpose_noalloc_noop(at::Tensor input,
at::Tensor output,
at::Tensor noop,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
auto noop_cu = makeTransformerEngineTensor(noop);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
nvte_transpose_with_noop(
input_cu.data(), noop_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
}
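From the Python side, the new noop-aware bindings can be driven roughly as in the sketch below. The extension alias `tex`, the tensor shapes, and the convention that a flag value of 1.0 makes the kernel a no-op are assumptions based on the surrounding code.

import torch
import transformer_engine_extensions as tex  # assumed extension module name

M, N = 128, 256
inp = torch.randint(0, 255, (M, N), dtype=torch.uint8, device="cuda")   # raw FP8 bytes
out = torch.empty((N, M), dtype=torch.uint8, device="cuda")             # preallocated transpose
noop = torch.zeros(1, dtype=torch.float32, device="cuda")               # 0.0 -> do the work, 1.0 -> skip

# The same kernel is launched regardless of the flag's value, so the call is
# safe to record during CUDA graph capture; the flag read at replay time
# decides whether any work is actually performed.
tex.fp8_transpose_noalloc_noop(inp, out, noop, tex.DType.kFloat8E4M3)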
......@@ -5,10 +5,10 @@
"""Methods needed for distributed training (DP/TP)."""
import warnings
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from typing import Any, Dict, Union, Optional, Callable, Tuple
from typing import Any, Dict, List, Union, Optional, Callable, Tuple
import torch
from torch.cuda import _lazy_call
from torch.cuda import _lazy_call, _lazy_init
from torch.utils.checkpoint import detach_variable, noop_context_fn
from .utils import safely_set_viewless_tensor_data
......@@ -31,15 +31,60 @@ _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False
def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None:
"""Sets the random number generator state of the current GPU.
_ALL_ACTIVE_RNG_STATES = {}
def get_all_rng_states() -> Dict:
"""Returns all generator states used by `CudaRNGStatesTracker`."""
return _ALL_ACTIVE_RNG_STATES
def set_all_rng_states(states: List) -> None:
"""Updates all generator states used by `CudaRNGStatesTracker`."""
global _ALL_ACTIVE_RNG_STATES
_ALL_ACTIVE_RNG_STATES = states
def graph_safe_rng_available() -> bool:
"""Returns whether cuda graph safe RNG state manipulation is supported."""
return (hasattr(torch.cuda.CUDAGraph, "register_generator_state")
and hasattr(torch.Generator, "graphsafe_set_state")
and hasattr(torch.Generator, "graphsafe_get_state")
and hasattr(torch.Generator, "clone_state"))
def _get_cuda_rng_state(
device: Union[int, str, torch.device] = "cuda",
clone: bool = False,
graph_safe: bool = True,
) -> torch.Tensor:
"""Return the random number generator state of the specified GPU."""
_lazy_init()
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
if graph_safe_rng_available() and graph_safe:
if clone:
# Reference to the cloned generator state
return default_generator.clone_state()
# Reference to the current generator state
return default_generator.graphsafe_get_state()
return default_generator.get_state()
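The practical difference between the two paths: `graphsafe_get_state()` returns a live, device-side reference that a `torch.cuda.CUDAGraph` can register and keep advancing across replays, while `get_state()` only returns a CPU byte-tensor snapshot. A minimal sketch, assuming a PyTorch build that exposes the graph-safe generator APIs:

import torch

gen = torch.cuda.default_generators[torch.cuda.current_device()]

if graph_safe_rng_available():
    state = gen.graphsafe_get_state()       # live reference, usable during capture
    cloned = gen.clone_state()              # independent copy, e.g. for a new named state
    gen.graphsafe_set_state(state)          # restore without materializing a byte tensor

    graph = torch.cuda.CUDAGraph()
    graph.register_generator_state(state)   # replays will advance this state correctly
else:
    snapshot = gen.get_state()              # CPU byte-tensor snapshot; not graph safe
    gen.set_state(snapshot)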
def _set_cuda_rng_state(
new_state: torch.Tensor,
device: Union[int, str] = -1,
graph_safe = True,
) -> None:
"""Sets the random number generator state of the current GPU."""
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
if device == -1:
device = torch.device("cuda")
elif isinstance(device, str):
......@@ -52,6 +97,9 @@ def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
if graph_safe_rng_available() and graph_safe:
default_generator.graphsafe_set_state(new_state)
return
default_generator.set_state(new_state)
_lazy_call(cb)
......@@ -206,7 +254,7 @@ class _CheckpointFunction(torch.autograd.Function):
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None:
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
......@@ -271,13 +319,13 @@ class _CheckpointFunction(torch.autograd.Function):
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None:
bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
......@@ -291,7 +339,7 @@ class _CheckpointFunction(torch.autograd.Function):
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
......@@ -317,6 +365,7 @@ class _CheckpointFunction(torch.autograd.Function):
)
return (None, None, None, None, None, None) + grads
class _CheckpointFrame:
"""
Storage frame for forward RNG states and detached activations from the forward recompute.
......@@ -338,7 +387,7 @@ class _CheckpointFrame:
"""Cache fwd/bwd RNG states in the frame to restore later."""
rng_states = (
torch.get_rng_state(),
torch.cuda.get_rng_state(),
_get_cuda_rng_state(graph_safe=False),
)
if self.get_rng_state_tracker is not None:
rng_states += (self.get_rng_state_tracker().get_states(), )
......@@ -356,7 +405,7 @@ class _CheckpointFrame:
rng_states = self.bwd_rng_states
torch.set_rng_state(rng_states[0])
_set_cuda_rng_state(rng_states[1])
_set_cuda_rng_state(rng_states[1], graph_safe=False)
if self.get_rng_state_tracker is not None:
self.get_rng_state_tracker().set_states(rng_states[2])
......@@ -604,6 +653,7 @@ def checkpoint(
return out
class CudaRNGStatesTracker:
"""
For model parallelism, multiple RNG states need to simultaneously exist in order
......@@ -664,13 +714,23 @@ class CudaRNGStatesTracker:
# Check that state is not already defined.
if name in self.states_:
raise Exception(f"cuda rng state {name} already exists")
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
if graph_safe_rng_available():
new_state = _get_cuda_rng_state(clone=True)
new_state.manual_seed(seed)
self.states_[name] = new_state
# Update global states.
set_all_rng_states(self.states_)
else:
# Get the current rng state.
orig_rng_state = _get_cuda_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = _get_cuda_rng_state(clone=True)
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
# Update global states.
set_all_rng_states(self.states_)
@contextmanager
def fork(self, name: str = "model-parallel-rng"):
......@@ -684,16 +744,17 @@ class CudaRNGStatesTracker:
# Check if we have added the state
if name not in self.states_:
raise Exception(f"cuda rng state {name} is not added")
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Get the reference to current rng state.
orig_cuda_rng_state = _get_cuda_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# this is redundant with graph-safe API
if not graph_safe_rng_available():
self.states_[name] = _get_cuda_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
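From a caller's perspective the tracker keeps its Megatron-style interface; only the underlying state objects become graph-safe references when the new APIs are available. A small usage sketch (the state name and seed are placeholders):

import torch

tracker = CudaRNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)      # cloned, independently seeded state

with tracker.fork("model-parallel-rng"):
    # RNG ops in here draw from the named state rather than the default one,
    # e.g. dropout masks that must differ across tensor-parallel ranks.
    mask = torch.rand(8, device="cuda") > 0.5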
......
......@@ -16,6 +16,7 @@ from .fp8 import FP8GlobalStateManager
aten = torch.ops.aten
c10d = torch.ops.c10d
updated_fp8_params = {}
def _make_fp8_attr_property_funcs(name: str) -> Any:
......@@ -67,6 +68,31 @@ class _FromFloat8Func(torch.autograd.Function):
return grad, None
def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
"""Amax scale and update when there is at least 1 trainable FP8 parameter."""
param_id = id(param._data)
if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
return
autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]
if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
return
if autocast_key in updated_fp8_params:
updated_fp8_params[autocast_key].add(param_id)
else:
updated_fp8_params[autocast_key] = {param_id}
current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
# All FP8 trainable parameters have been updated.
if updated_fp8_params[autocast_key] == current_fp8_params_set:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(
forward=True, fp8_weights=True)
del updated_fp8_params[autocast_key]
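Taken together with the autocast-exit and backward-hook changes in fp8.py, the reduction points of a training step end up roughly as sketched below (assumes a CUDA device; `te.Linear` stands in for any FP8-capable module, and the weight branch only applies when parameters are stored as `Float8Tensor`):

import torch
import transformer_engine.pytorch as te

model = te.Linear(128, 128)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(3):
    inp = torch.randn(32, 128, device="cuda")
    with te.fp8_autocast(enabled=True):
        loss = model(inp).sum()
    # Exiting the autocast reduces and updates the forward (non-weight) FP8 tensors.

    loss.backward()
    # A multi-grad hook fires at the end of backward and reduces the backward tensors.

    optimizer.step()
    # For FP8 parameters, the in-place copy lands in Float8Tensor.__torch_dispatch__,
    # and once every trainable FP8 weight has been updated this function reduces
    # and updates the weight tensors.
    optimizer.zero_grad()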
class _ToFloat8Func(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
@staticmethod
......@@ -167,6 +193,7 @@ class _ToFloat8Func(torch.autograd.Function):
# Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None
class _IdentityFunc(torch.autograd.Function):
"""Identity function
......@@ -307,8 +334,9 @@ class Float8Tensor(torch.Tensor):
), f"Unsupported fp8_dtype {fp8_dtype}."
self._fp8_dtype: tex.DType = fp8_dtype
# Cached transpose
# Transposed version of `_data`.
self._transpose: Optional[Float8Tensor] = None
self._transpose_invalid: bool = True
# FP8 scale-inverse
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
......@@ -435,80 +463,51 @@ class Float8Tensor(torch.Tensor):
return _IdentityFunc.apply(self)
return super().expand_as(other)
def transpose(
def transpose_2d(
self,
dim0: int = 0,
dim1: int = 1,
*,
update_cache: str | bool = "reuse_only",
cache: bool = False,
noop_flag: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Swap tensor dimensions
For basic 2D matrix transposes, an optimized transpose kernel
is applied and a Float8Tensor is returned.
2D transpose with caching support.
Parameters
----------
dim0: int, default = 0
The first dimension to be transposed
dim1: int, default = 1
The second dimension to be transposed
update_cache: str or bool, default = "reuse_only"
Memoization behavior. Options are
"reuse_only"/`False` (reuse cached value if
available, otherwise calculate transpose without
caching), "force"/`True` (calculate transpose
and cache), "lazy" (reuse cached value if
available, otherwise calculate transpose and
cache if possible). Caching is only supported
for basic 2D transposes and the cache is reset
after any in-place operations.
cache: bool, default = `False`
Whether or not to cache the transpose.
noop_flag: Optional[torch.Tensor], default = `None`
Only used if argument `cache` is `True`, ignored otherwise.
A single element fp32 tensor with a value of 1.0 or 0.0
which is treated as a boolean. `0.0` forces recompute
and `1.0` executes a noop using the same kernel.
"""
assert self.dim() == 2, f"{self.dim()}-D transpose not supported."
# Check caching mode
if not isinstance(update_cache, str):
update_cache = "force" if update_cache else "reuse_only"
if update_cache not in ("force", "reuse_only", "lazy"):
raise ValueError(
"Supported values for update_cache are "
'"force" (True), "reuse_only" (False), "lazy" '
f"(got {update_cache})"
)
# Case: no caching.
if not cache:
return tex.fp8_transpose(self._data, self._fp8_dtype)
# Handle non-2D transposes
if -self.dim() <= dim0 < 0:
dim0 += self.dim()
if -self.dim() <= dim1 < 0:
dim1 += self.dim()
if self.dim() != 2 or dim0 == dim1:
if update_cache == "force":
raise ValueError(
"Transpose caching is only supported for basic 2D transposes "
f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
)
return super().transpose(dim0, dim1)
# Clear cache if needed
if update_cache == "force":
self._transpose = None
# Compute transpose if needed
out = self._transpose
if out is None:
out = Float8Tensor.make_like(
self,
data=tex.fp8_transpose(
self._data.contiguous(),
self._fp8_dtype,
),
)
# Case: reuse cache without calling a kernel.
if not self._transpose_invalid and noop_flag is None:
assert self._transpose is not None, "Transpose cache is empty."
return self._transpose
# Update cache if needed
if update_cache in ("force", "lazy"):
self._transpose = out
return out
# Allocate transpose if needed.
data_2d = self._data.reshape(-1, self._data.shape[-1])
if self._transpose is None:
shape = (data_2d.shape[1], data_2d.shape[0])
self._transpose = torch.empty(shape, dtype=torch.uint8, device=self._data.device)
# Case: recompute transpose and store cache.
if noop_flag is None:
tex.fp8_transpose_noalloc(data_2d, self._transpose, self._fp8_dtype)
else:
# Case: cuda graph capture.
tex.fp8_transpose_noalloc_noop(data_2d, self._transpose, noop_flag, self._fp8_dtype)
self._transpose_invalid = False
return self._transpose
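A sketch of how the cache and the noop flag are meant to be combined (here `t` is assumed to be a 2D `Float8Tensor`):

import torch

# Eager path: the first call launches a transpose kernel and fills the cache;
# later calls reuse the cached buffer until an in-place op invalidates it.
t_t = t.transpose_2d(cache=True)
t_t = t.transpose_2d(cache=True)                 # cache hit, no kernel launch

# Capture path: the kernel must always be launched so that it is recorded in
# the CUDA graph; at replay the flag decides whether the transpose is
# recomputed (0.0) or the cached buffer is left untouched (1.0).
noop = torch.zeros(1, dtype=torch.float32, device="cuda")
t_t = t.transpose_2d(cache=True, noop_flag=noop)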
@torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None:
......@@ -519,13 +518,11 @@ class Float8Tensor(torch.Tensor):
the tensor.
"""
if self._fp8_meta is None:
return
assert self._fp8_meta is not None, "FP8 meta tensors not found."
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
scale_inv.view(1).copy_(self._scale_inv.view(1))
self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0])
def to_dtype(self, dtype: torch.dtype) -> Float8Tensor:
"""Create `Float8Tensor` with given nominal dtype
......@@ -541,12 +538,11 @@ class Float8Tensor(torch.Tensor):
)
def _reset_caches(self) -> None:
"""Reset cached values
"""
Set transpose cache as invalid.
Should be called after any in-place operation.
"""
self._transpose = None
self._transpose_invalid = True
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......@@ -574,7 +570,7 @@ class Float8Tensor(torch.Tensor):
# Directly copy FP8 data if possible
if dst._fp8_dtype == src._fp8_dtype:
dst._data.copy_(src._data)
dst._scale_inv = src._scale_inv.clone()
dst._scale_inv.copy_(src._scale_inv.detach())
if dst._fp8_meta is not None:
if src._fp8_meta is None:
src_min, src_max = src.from_float8().aminmax()
......@@ -600,7 +596,6 @@ class Float8Tensor(torch.Tensor):
dst.copy_(src.from_float8())
elif dst_is_fp8 and not src_is_fp8:
# Make sure input is in expected format
src = src.expand(dst.size())
src = src.to(
......@@ -619,7 +614,7 @@ class Float8Tensor(torch.Tensor):
fp8_meta_index = dst._fp8_meta_index
scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
dst._scale_inv = scale.detach().view(1).reciprocal()
dst._scale_inv.copy_(scale.detach().reciprocal())
# Cast to FP8
if not dst._data.is_contiguous():
......@@ -633,6 +628,9 @@ class Float8Tensor(torch.Tensor):
dst._fp8_dtype,
)
# This branch is where the FP8 parameters are updated in-place during optimization.
# Handle forward amax reduction.
post_optimizer_step_fwd_amax_reduction(dst)
else:
# Invalid case
......@@ -641,6 +639,7 @@ class Float8Tensor(torch.Tensor):
# Nothing to return for in-place ops
if dst_is_fp8:
dst._reset_caches()
return None
# Slice op
......@@ -764,6 +763,7 @@ class Float8Tensor(torch.Tensor):
_fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index"))
_fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype"))
_transpose = property(**_make_fp8_attr_property_funcs("transpose"))
_transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
# Do not force the Float8Tensor type on the returned tensor
......
......@@ -51,6 +51,17 @@ def get_fp8_te_dtype(
return tex.DType.kFloat8E5M2
def get_fp8_max(
fp8_recipe: DelayedScaling, fprop_tensor: bool = True
) -> tex.DType:
"""Get max representible FP8 value."""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return Format.E4M3.value.max_fwd
return Format.E5M2.value.max_fwd
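For the delayed-scaling recipe these are the usual FP8 limits; a quick check (assuming the helper is imported from transformer_engine.pytorch.fp8):

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import get_fp8_max

recipe = DelayedScaling(fp8_format=Format.HYBRID)
assert get_fp8_max(recipe, fprop_tensor=True) == 448.0      # E4M3 maximum
assert get_fp8_max(recipe, fprop_tensor=False) == 57344.0   # E5M2 maximum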
class FP8GlobalStateManager:
"""Class to keep track of and manipulate the global
FP8 state at different stages of execution.
......@@ -61,20 +72,21 @@ class FP8GlobalStateManager:
FP8_DISTRIBUTED_GROUP = None
FP8_PARAMETERS = False
IS_FIRST_FP8_MODULE = False
FP8_AUTOCAST_COUNTER = 0
FP8_CURRENT_CONTEXT_ID = 0
FP8_GRAPH_CAPTURING = False
FP8_AUTOCAST_DEPTH = 0
global_fp8_buffer = {}
global_amax_buffer = {}
global_amax_history_buffer = {}
global_scale_buffer = {}
global_scale_inv_buffer = {}
fp8_tensors_recompute_buffer = []
amax_forward_global_reduce_func = None
buffer_delete_key_fwd = None
buffer_delete_key_bwd = None
amax_reduce_handle_fwd = None
fp8_available = None
reason_for_no_fp8 = ""
dp_amax_reduce_interval = None
dp_amax_reduce_forward_idx = 0
dp_amax_reduce_backward_idx = 0
multi_grad_hook_tensors = []
bwd_amax_update_hook_registered = False
autocast_arguments = {}
autocast_to_fp8_params = {}
fp8_param_to_autocast = {}
skip_fp8_weight_update_tensor = None
@classmethod
def reset(cls) -> None:
......@@ -83,21 +95,35 @@ class FP8GlobalStateManager:
cls.FP8_CALIBRATION = False
cls.FP8_RECIPE = None
cls.FP8_DISTRIBUTED_GROUP = None
cls.FP8_PARAMETERS = False
cls.IS_FIRST_FP8_MODULE = False
cls.FP8_AUTOCAST_COUNTER = 0
cls.FP8_CURRENT_CONTEXT_ID = 0
cls.FP8_GRAPH_CAPTURING = False
cls.FP8_AUTOCAST_DEPTH = 0
cls.global_fp8_buffer = {}
cls.global_amax_buffer = {}
cls.global_amax_history_buffer = {}
cls.global_scale_buffer = {}
cls.global_scale_inv_buffer = {}
cls.fp8_tensors_recompute_buffer = []
cls.amax_forward_global_reduce_func = None
cls.buffer_delete_key_fwd = None
cls.buffer_delete_key_bwd = None
cls.amax_reduce_handle_fwd = None
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.dp_amax_reduce_interval = None
cls.dp_amax_reduce_forward_idx = 0
cls.dp_amax_reduce_backward_idx = 0
cls.multi_grad_hook_tensors = []
cls.bwd_amax_update_hook_registered = False
cls.autocast_arguments = {}
cls.autocast_to_fp8_params = {}
cls.fp8_param_to_autocast = {}
cls.skip_fp8_weight_update_tensor = None
@classmethod
def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
"""`skip_fp8_weight_update_tensor` inplace setter."""
if cls.skip_fp8_weight_update_tensor is None:
cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
cls.skip_fp8_weight_update_tensor.fill_(skip)
@classmethod
def get_skip_fp8_weight_update_tensor(cls) -> torch.Tensor:
"""`skip_fp8_weight_update_tensor` getter."""
return cls.skip_fp8_weight_update_tensor
@classmethod
def is_fp8_available(cls) -> Tuple[bool, str]:
......@@ -106,44 +132,6 @@ class FP8GlobalStateManager:
cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
return cls.fp8_available, cls.reason_for_no_fp8
@classmethod
def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]:
"""Returns global fp8 state variables."""
# Convert attributes to dictionary to make future proof against
# changes in global state variables in order to make setting the
# checkpoint backwards compatible.
global_fp8_state = {}
global_fp8_state["FP8_AUTOCAST_COUNTER"] = cls.FP8_AUTOCAST_COUNTER
global_fp8_state["FP8_CURRENT_CONTEXT_ID"] = cls.FP8_CURRENT_CONTEXT_ID
global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH
global_fp8_state["buffer_delete_key_fwd"] = cls.buffer_delete_key_fwd
global_fp8_state["buffer_delete_key_bwd"] = cls.buffer_delete_key_bwd
global_fp8_state["dp_amax_reduce_interval"] = cls.dp_amax_reduce_interval
global_fp8_state["dp_amax_reduce_forward_idx"] = cls.dp_amax_reduce_forward_idx
global_fp8_state["dp_amax_reduce_backward_idx"] = cls.dp_amax_reduce_backward_idx
return global_fp8_state
@classmethod
def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> None:
"""Sets global fp8 state variables."""
for k, v in state.items():
if hasattr(cls, k):
setattr(cls, k, v)
@classmethod
def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]:
"""Returns global fp8 amax buffer."""
return cls.global_fp8_buffer
@classmethod
def set_global_fp8_buffer_checkpoint(cls, buffer: Dict[str, List[torch.Tensor]]) -> None:
"""Sets global fp8 amax buffer."""
# Map all tensors back to GPU.
for k, v in buffer.items():
buffer[k] = [tensor.cuda() for tensor in v]
cls.global_fp8_buffer = buffer
@staticmethod
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
......@@ -152,121 +140,102 @@ class FP8GlobalStateManager:
return "scaling_bwd"
@staticmethod
def get_buffer_position_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "global_fp8_buffer_pos_fwd"
return "global_fp8_buffer_pos_bwd"
@staticmethod
def get_autocast_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "autocast_id_fwd"
return "autocast_id_bwd"
@staticmethod
def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
"""Return a key in `_global_fp8_buffer` for the AMAX storage."""
if forward:
return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
def get_fwd_bwd_key(forward: bool = True) -> str:
"""Convert bool `forward` to string."""
return "forward" if forward else "backward"
@classmethod
def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]:
"""Return AMAX reduction wait handle of forward prop."""
return cls.amax_reduce_handle_fwd
def get_buffer_info(cls) -> str:
"""
Returns a key for `fp8_meta` that stores the module's index
in the global buffers along with autocast information.
"""
return "buffer_index_and_autocast_key"
@classmethod
def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None:
"""Sets up the function to call during autocast exit."""
cls.amax_forward_global_reduce_func = f
def get_key_in_buffer(
cls,
forward: bool,
fp8_weights: bool,
fp8_recipe: DelayedScaling,
fp8_group: dist_group_type,
) -> str:
"""Returns a key into the global FP8 buffers."""
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
fwd_bwd_key = cls.get_fwd_bwd_key(forward)
return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}"
@classmethod
def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None:
"""Append 1D tensor `amax` to global buffer."""
buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
buffer_position_key = cls.get_buffer_position_key(forward=forward)
if buffer_key not in cls.global_fp8_buffer:
cls.global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
else:
cls.global_fp8_buffer[buffer_key].append(
fp8_meta[fp8_meta_tensor_key].amax_history[0]
)
if buffer_position_key not in fp8_meta:
fp8_meta[buffer_position_key] = len(cls.global_fp8_buffer[buffer_key]) - 1
# Catch incorrect fp8_autocast usage.
assert fp8_meta[buffer_position_key] == len(cls.global_fp8_buffer[buffer_key]) - 1, \
"Same module is being invoked more than once inside an `fp8_autocast` " \
"region when using FP8 with amax reduction. This behavior is currently" \
" unsupported. For more details and correct usage, please see " \
"https://github.com/NVIDIA/TransformerEngine/pull/93."
def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]:
"""Splits buffer key into relevant parts."""
forward, fp8_weights, autocast_key = key.split("_", 2)
forward = forward == "forward"
fp8_weights = fp8_weights == "True"
return forward, fp8_weights, autocast_key
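The key layout is easy to see in isolation; `split_key_in_buffer` uses `maxsplit=2`, so the autocast key itself may contain underscores:

from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling()
key = FP8GlobalStateManager.get_key_in_buffer(
    forward=True, fp8_weights=False, fp8_recipe=recipe, fp8_group=None)
# e.g. "forward_False_DelayedScaling(...):<hash of the group>"
fwd, weights, autocast_key = FP8GlobalStateManager.split_key_in_buffer(key)
assert fwd is True and weights is False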
@classmethod
def copy_amax_from_global_buffer(
cls, fp8_meta: Dict[str, Any], forward: bool = True
def add_fp8_tensors_to_global_buffer(
cls,
fp8_meta: Dict[str, Any],
fp8_weights: Optional[List[torch.Tensor]] = None,
) -> None:
"""Populate current amax with the correct location from buffer."""
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
buffer_position_key = cls.get_buffer_position_key(forward=forward)
if buffer_position_key not in fp8_meta:
return
amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
assert amax_buffer_key in cls.global_fp8_buffer, "TE internal error."
fp8_meta[fp8_meta_tensor_key].amax_history[0] = cls.global_fp8_buffer[amax_buffer_key][
fp8_meta[buffer_position_key]
]
"""
The amax reduction process happens completely outside the FP8 modules.
To participate in the reduction, the only role played by a module is
to call this function in order to append its FP8 tensor into a global
buffer. There are 5 global buffers maintained, one each for amax, amax
history, scale, scale-inverse, and non-weight-mask. Each buffer has
keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix
to indicate the type of FP8 tensor, since the forward and backward
reductions happen separately.
Note: For CG capture, this method is called from the graphed
wrapper. For non CG case, it's called from within the module.
"""
@classmethod
def set_amax_buffer_key_deletion(
cls, fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Delete this amax key from global buffer during autocast end."""
if cls.get_autocast_key(forward=forward) not in fp8_meta:
# Every module must call this function exactly once since
# the amax tensors are static. Ensures that compatibility
# with non-graphed modules is maintained.
index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors.
if index_in_buffer in fp8_meta:
return
if forward:
cls.buffer_delete_key_fwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
else:
cls.buffer_delete_key_bwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
@classmethod
def delete_key_from_amax_buffer(cls, forward: bool = True) -> None:
"""Delete the key from global amax buffer."""
if forward:
if (
cls.buffer_delete_key_fwd is not None
and cls.buffer_delete_key_fwd in cls.global_fp8_buffer
):
del cls.global_fp8_buffer[cls.buffer_delete_key_fwd]
else:
if (
cls.buffer_delete_key_bwd is not None
and cls.buffer_delete_key_bwd in cls.global_fp8_buffer
):
del cls.global_fp8_buffer[cls.buffer_delete_key_bwd]
@classmethod
def get_fp8_context_id(cls) -> int:
"""Returns an ID for the current FP8 context."""
return cls.FP8_CURRENT_CONTEXT_ID
@classmethod
def set_fp8_context_id(cls, ctx_id: int) -> None:
"""Sets the current FP8 context."""
cls.FP8_CURRENT_CONTEXT_ID = ctx_id
@classmethod
def new_fp8_context_id(cls) -> int:
"""Returns global autocast counter as a proxy to be used
as the autocast ID for FP8 modules.
"""
return cls.FP8_AUTOCAST_COUNTER
fp8_meta[index_in_buffer] = []
for forward in (True, False):
# This algorithm creates a two-way map with `autocast_to_fp8_params` and
# `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
# in an autocasted region and cross reference them in `float8_tensor.py`
# to perform the forward amax reduction.
if forward and fp8_weights is not None:
autocast_key = cls.get_unique_autocast_key(
fp8_meta["recipe"], fp8_meta["fp8_group"])
fp8_weight_set = {id(w._data) for w in fp8_weights}
if autocast_key not in cls.autocast_to_fp8_params:
cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set
else:
cls.autocast_to_fp8_params[autocast_key] = (
cls.autocast_to_fp8_params[autocast_key].union(fp8_weight_set))
# Identify correct autocast key for a given param.
for w in fp8_weight_set:
cls.fp8_param_to_autocast[w] = autocast_key
key = cls.get_key_in_buffer(
forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"])
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if key not in cls.global_amax_buffer:
cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history]
cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale]
cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv]
else:
cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
cls.global_amax_history_buffer[key].append(
fp8_meta[fp8_meta_tensor_key].amax_history)
cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale)
cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv)
fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
fp8_meta[index_in_buffer].append(key)
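As an illustration of the resulting bookkeeping when two modules register under the same autocast (list names are the class attributes introduced above; values are placeholders):

# global_amax_buffer[key]         -> [amax_row_mod0, amax_row_mod1]
# global_amax_history_buffer[key] -> [amax_history_mod0, amax_history_mod1]
# global_scale_buffer[key]        -> [scale_mod0, scale_mod1]
# global_scale_inv_buffer[key]    -> [scale_inv_mod0, scale_inv_mod1]
# and, for the second module (index 1 in both the fwd- and bwd-keyed lists):
# fp8_meta["buffer_index_and_autocast_key"] == [1, fwd_key, 1, bwd_key]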
@classmethod
def is_fp8_enabled(cls) -> bool:
......@@ -283,6 +252,11 @@ class FP8GlobalStateManager:
"""Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS
@classmethod
def fp8_graph_capturing(cls) -> bool:
"""Is CUDA graph capture under way?"""
return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()
@classmethod
def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple
......@@ -310,7 +284,8 @@ class FP8GlobalStateManager:
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE)
cls.IS_FIRST_FP8_MODULE,
cls.FP8_GRAPH_CAPTURING)
@classmethod
def set_fp8_autocast_state(
......@@ -322,80 +297,100 @@ class FP8GlobalStateManager:
cls.FP8_CALIBRATION,
cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE) = fp8_state
cls.IS_FIRST_FP8_MODULE,
cls.FP8_GRAPH_CAPTURING) = fp8_state
@staticmethod
def reduce_tensor_across_group_op_max(
tensor: torch.Tensor, group: dist_group_type, async_op: bool
tensor: torch.Tensor, group: dist_group_type
) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
wait_handle = torch.distributed.all_reduce(
torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=async_op,
async_op=False,
)
return wait_handle
return None
@classmethod
def global_amax_reduction(
def reduce_and_update_fp8_tensors(
cls,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
forward: bool = True,
fp8_weights: bool = False,
) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer."""
amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward)
# Key already deleted.
if amax_buffer_key not in cls.global_fp8_buffer:
return None
# Reduce AMAX in DP-domain at an interval.
# `NVTE_DP_AMAX_REDUCE_INTERVAL` should be set as an integer value larger than 0. If
# `NVTE_DP_AMAX_REDUCE_INTERVAL` is set to 0, AMAX is reduced only in TP domain.
if cls.dp_amax_reduce_interval is None:
cls.dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
if cls.dp_amax_reduce_interval == 0:
tp_amax_reduce = True
else:
tp_amax_reduce = False
if forward:
if cls.dp_amax_reduce_forward_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
tp_amax_reduce = True
cls.dp_amax_reduce_forward_idx = (
(cls.dp_amax_reduce_forward_idx + 1) % cls.dp_amax_reduce_interval)
for buffer_key, amax_buffer in cls.global_amax_buffer.items():
# Check for forward or backward reduction.
fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key)
if fwd_update != forward:
continue
# Only skip a forward update when `fp8_weights` is explicitly set to `True`
# (inside optimizer) and the current key is not an `fp8_weight_update` key.
# For other cases, we need to reduce because of activation tensors.
# TODO(ksivaman) consider separate weight and activation fp8_tensors.
if fwd_update and fp8_weights and not fp8_weights_update:
continue
if len(amax_buffer) == 0:
continue
# Retrieve autocast specific args and concat amaxes.
recipe, group = cls.autocast_arguments[autocast_key]
contiguous_amax = torch.cat(amax_buffer)
# Reduction.
if (recipe.reduce_amax
and torch.distributed.is_initialized()
and torch.distributed.get_world_size(group=group) > 1):
cls.reduce_tensor_across_group_op_max(contiguous_amax, group)
# Amax and scale update.
unfused_update = (bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
or callable(recipe.amax_compute_algo)
or callable(recipe.scaling_factor_compute_algo))
if not unfused_update:
tex.fused_amax_and_scale_update_after_reduction(
contiguous_amax,
cls.global_amax_history_buffer[buffer_key],
cls.global_scale_buffer[buffer_key],
cls.global_scale_inv_buffer[buffer_key],
recipe.amax_compute_algo,
get_fp8_te_dtype(recipe, forward),
recipe.margin,
)
else:
if cls.dp_amax_reduce_backward_idx == 0:
reduce_group = fp8_meta["fp8_group"]
else:
tp_amax_reduce = True
cls.dp_amax_reduce_backward_idx = (
(cls.dp_amax_reduce_backward_idx + 1) % cls.dp_amax_reduce_interval)
split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])
if tp_amax_reduce:
if tp_size > 1:
reduce_group = tp_group
else:
return None
for amax_history, scale, scale_inv in zip(
cls.global_amax_history_buffer[buffer_key],
cls.global_scale_buffer[buffer_key],
cls.global_scale_inv_buffer[buffer_key],
):
_amax_and_scale_update(
amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe)
chunk_sizes = [x.numel() for x in cls.global_fp8_buffer[amax_buffer_key]]
contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key])
@classmethod
def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor):
"""Add tensor to list for multi grad hook."""
cls.multi_grad_hook_tensors.append(tensor)
wait_handle = cls.reduce_tensor_across_group_op_max(
contiguous_amax,
reduce_group,
fp8_meta["async_amax_reduction"],
)
@classmethod
def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument
"""Executes at the end of backward pass."""
cls.reduce_and_update_fp8_tensors(forward=False)
cls.global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
return wait_handle
@classmethod
def get_unique_autocast_key(
cls,
recipe: Optional[DelayedScaling] = None,
group: Optional[dist_group_type] = None,
):
"""
For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
Safely using `hash` as we never cross checkpoint boundaries.
"""
return f"{str(recipe)}:{hash(group)}"
@classmethod
def fp8_autocast_enter(
......@@ -404,21 +399,29 @@ class FP8GlobalStateManager:
calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
"""Set state and tracking variables for entry into FP8 region."""
if cls.FP8_AUTOCAST_DEPTH == 0:
if callable(cls.amax_forward_global_reduce_func):
cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable
cls.delete_key_from_amax_buffer(forward=True)
fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0:
# This hook does not fire for graphed modules.
torch.autograd.graph.register_multi_grad_hook(
tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction)
cls.bwd_amax_update_hook_registered = True
cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
cls.FP8_RECIPE = fp8_recipe
cls.FP8_DISTRIBUTED_GROUP = fp8_group
cls.FP8_GRAPH_CAPTURING = _graph
if cls.FP8_AUTOCAST_DEPTH == 0:
cls.IS_FIRST_FP8_MODULE = True
cls.FP8_AUTOCAST_COUNTER += 1
cls.FP8_AUTOCAST_DEPTH += 1
if enabled:
......@@ -426,9 +429,14 @@ class FP8GlobalStateManager:
assert fp8_available, reason_for_no_fp8
@classmethod
def fp8_autocast_exit(cls):
def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
"""Set state and tracking variables for exit from FP8 region."""
cls.FP8_AUTOCAST_DEPTH -= 1
# Reduce only the non-FP8 weight modules here.
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False)
@classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
......@@ -525,6 +533,7 @@ def fp8_autocast(
calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None:
"""
Context manager for FP8 usage.
......@@ -568,23 +577,25 @@ def fp8_autocast(
FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
calibrating=calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group)
fp8_group=fp8_group,
_graph=_graph)
yield
finally:
FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment
FP8GlobalStateManager.fp8_autocast_exit()
FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
"""Update amax history and set next amax to zero."""
if amax_history.shape[0] > 1:
amax_history = torch.roll(amax_history, -1, 0)
new_amax_history = torch.roll(amax_history, -1, 0)
amax_history.copy_(new_amax_history)
amax_history[0].fill_(0.0)
return amax_history
@torch.jit.script
def _default_get_amax(
def _default_get_amax_and_update_history(
amax_history: torch.Tensor,
amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -609,63 +620,23 @@ def _default_sf_compute(
sf = (fp8_max / amax) / (2 ** margin)
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
return sf
@jit_fuser
def _compute_scaling_factor_inverse(
scale: torch.Tensor,
scale_inv: torch.Tensor,
non_weight_mask: torch.Tensor,
update_weight_scale_inv: bool,
) -> torch.Tensor:
"""Compute inverse of scaling factor."""
if update_weight_scale_inv:
return 1.0 / scale
return torch.where(non_weight_mask, 1.0 / scale, scale_inv)
def _fused_amax_and_scale_update(
amax_history: torch.Tensor,
scale: torch.Tensor,
scale_inv: torch.Tensor,
fp8_dtype: tex.DType,
margin: int,
amax_compute_algo: str,
non_weight_mask: torch.Tensor,
update_weight_scale_inv: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Update amax history and FP8 scaling factors"""
if update_weight_scale_inv:
non_weight_mask = torch.Tensor()
tex.fused_amax_and_scale_update(
amax_history,
scale,
scale_inv,
non_weight_mask,
amax_history,
scale,
scale_inv,
amax_compute_algo,
fp8_dtype,
margin,
)
return amax_history, scale, scale_inv
scale.copy_(sf)
return scale
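A concrete instance of the default scale update (history roll omitted): with `fp8_max = 448` (E4M3), `margin = 0`, and an observed `amax = 2.0`, the new scale is `(448 / 2.0) / 2**0 = 224.0` and `scale_inv` becomes `1 / 224.0`:

import torch

amax = torch.tensor([2.0])
scale = torch.ones(1)
fp8_max, margin = 448.0, 0

sf = (fp8_max / amax) / (2 ** margin)
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
assert sf.item() == 224.0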
def _compute_amax(
def _compute_amax_and_update_history(
amax_history: torch.Tensor,
recipe: DelayedScaling,
amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Obtain the amax from the history."""
if callable(recipe.amax_compute_algo):
amax = recipe.amax_compute_algo(amax_history)
if callable(amax_compute_algo):
amax = amax_compute_algo(amax_history)
amax_history = _update_amax_history(amax_history)
return amax_history, amax
return _default_get_amax(
return _default_get_amax_and_update_history(
amax_history,
recipe.amax_compute_algo,
amax_compute_algo,
)
......@@ -687,46 +658,29 @@ def _compute_scaling_factor(
return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)
def amax_and_scale_update(
fp8_meta: Dict[str, Any],
fwd_update: bool,
update_weight_scale_inv: bool = True,
def _amax_and_scale_update(
amax_history: torch.Tensor,
scale: torch.Tensor,
scale_inv: torch.Tensor,
fp8_max: float,
recipe: DelayedScaling,
) -> None:
"""Updates fp8 amaxes/scales for fwd | bwd."""
amax_compute = fp8_meta["recipe"].amax_compute_algo
sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
if not callable(amax_compute) and sf_compute is None:
(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
) = _fused_amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
get_fp8_te_dtype(fp8_meta["recipe"], fwd_update),
fp8_meta["recipe"].margin,
fp8_meta["recipe"].amax_compute_algo,
fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
update_weight_scale_inv,
)
else:
fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta["recipe"],
)
fp8_meta[fp8_meta_tensor_key].scale = _compute_scaling_factor(
amax,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_max_key],
fp8_meta["recipe"],
)
fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse(
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
update_weight_scale_inv,
)
"""Updates FP8 meta tensors."""
new_amax_history, amax = _compute_amax_and_update_history(
amax_history,
recipe.amax_compute_algo,
)
new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe)
scale.copy_(new_scale)
scale_inv.copy_(1.0 / new_scale)
amax_history.copy_(new_amax_history)
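# Illustrative sketch (not part of this patch): the updates above use copy_() so that
# scale, scale_inv and amax_history keep their storage. A captured CUDA graph records
# raw pointers, so the FP8 meta tensors must be updated in place rather than rebound.
# The tensor below is hypothetical.
import torch

scale = torch.ones(4)
ptr_before = scale.data_ptr()
scale.copy_(2.0 * scale)                 # in-place: same storage, new values
assert scale.data_ptr() == ptr_before    # address unchanged, safe for graph replay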
def split_and_copy(
buffer: torch.Tensor,
outputs: List[torch.Tensor],
chunk_sizes: List[int],
) -> None:
"""Split `buffer` by `chunk_sizes` and copy into `outputs`."""
splits = buffer.split(chunk_sizes)
torch._foreach_copy_(outputs, splits)
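# Illustrative sketch (not part of this patch): using split_and_copy to scatter one flat
# buffer into preallocated output tensors. The sizes are hypothetical.
import torch

buffer = torch.arange(6, dtype=torch.float32)
outputs = [torch.empty(2), torch.empty(4)]
split_and_copy(buffer, outputs, [2, 4])  # outputs now hold [0., 1.] and [2., 3., 4., 5.]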
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Functions for CUDA Graphs support in FP8"""
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten
from torch.utils._pytree import tree_unflatten as _tree_unflatten
from torch._C import _graph_pool_handle
from .fp8 import (
fp8_autocast,
FP8GlobalStateManager,
get_default_fp8_recipe,
)
from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule
__all__ = ["make_graphed_callables"]
_IS_GRAPH_CAPTURING = False
def set_capture_start() -> None:
"""Record beginning of `make_graphed_callables`."""
global _IS_GRAPH_CAPTURING
_IS_GRAPH_CAPTURING = True
def set_capture_end() -> None:
"""Record end of `make_graphed_callables`."""
global _IS_GRAPH_CAPTURING
_IS_GRAPH_CAPTURING = False
def is_graph_capturing() -> bool:
"""Return whether within `make_graphed_callables`."""
return _IS_GRAPH_CAPTURING
def graph_pool_handle():
"""
Returns an opaque token representing the id of a graph memory pool.
"""
return _graph_pool_handle()
def _make_graphed_callables(
callables,
sample_args,
num_warmup_iters=3,
allow_unused_input=False,
fp8_weight_caching=False,
_order=None,
):
"""
Helper method for `make_graphed_callables`
"""
if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
raise RuntimeError(
"make_graphed_callables does not support the autocast "
"caching. Please set `cache_enabled=False`."
)
just_one_callable = False
if not isinstance(callables, tuple):
just_one_callable = True
callables = (callables,)
sample_args = (sample_args,)
flatten_sample_args = []
if _order is not None:
# `_order` lists model chunk ids (1..num_model_chunks) in microbatch schedule order;
# positive ids mark forward passes and negative ids the matching backward passes.
num_model_chunks = max(_order)
num_microbatches = len(_order) // num_model_chunks // 2
assert num_model_chunks * num_microbatches * 2 == len(_order)
assert (
len(sample_args)*2 >= len(_order)
and (len(sample_args)*2 % len(_order) == 0)
), f'{len(sample_args)*2} >= {len(_order)} and {len(sample_args)*2} % {len(_order)} == 0'
num_layers = len(sample_args) // num_model_chunks // num_microbatches
assert (
len(callables) == num_model_chunks*num_layers
), (f"Callables should have ({num_model_chunks * num_layers}) "
+ f"entries when order input is provided but got {len(callables)}."
)
assert (
len(sample_args) == num_model_chunks * num_microbatches * num_layers
), (f"Expected {num_model_chunks * num_microbatches}"
+ f"args tuple, but got {len(sample_args)}."
)
if fp8_weight_caching:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
for c in callables:
if isinstance(c, torch.nn.Module):
assert (
len(c._backward_hooks) == 0
and len(c._forward_hooks) == 0
and len(c._forward_pre_hooks) == 0
), (
"Modules must not have hooks registered at the time they are passed. "
+ "However, registering hooks on modules after passing them "
+ "through make_graphed_callables is allowed."
)
assert all(b.requires_grad is False for b in c.buffers()), (
"In any :class:`~torch.nn.Module` passed to "
+ ":func:`~make_graphed_callables`, only parameters may be trainable. "
+ "All buffers must have ``requires_grad=False``."
)
for args in sample_args:
flatten_arg, _ = _tree_flatten(args)
flatten_sample_args.append(tuple(flatten_arg))
assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
"In the beta API, sample_args "
+ "for each callable must contain only Tensors. Other types are not allowed."
)
# If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
# passes to forward (ie, its sample_args) AND the module's parameter attributes.
per_callable_len_user_args = [len(args) for args in flatten_sample_args]
if _order is None:
per_callable_module_params = [
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
for c in callables
]
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
for i in range(len(callables))
]
else:
per_callable_module_params = []
for c in callables:
for i in range(num_microbatches):
per_callable_module_params.append(
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
)
assert len(per_callable_module_params) == len(flatten_sample_args)
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
for i in range(len(flatten_sample_args))
]
fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
graph_callables = [None for _ in range(len(flatten_sample_args))]
# For cases with multiple active RNG states, e.g. TP.
if graph_safe_rng_available():
for _, state in get_all_rng_states().items():
for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
fwd_graph.register_generator_state(state)
bwd_graph.register_generator_state(state)
mempool = graph_pool_handle()
# Warmup
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):
for c_i, func in enumerate(callables):
args = sample_args[c_i]
static_input_surface = per_callable_static_input_surfaces[c_i]
for _ in range(num_warmup_iters):
outputs, _ = _tree_flatten(func(*args))
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
only_inputs=True,
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
torch.cuda.synchronize()
# All captures here share a mempool. To avoid replays corrupting each other's memory,
# the safest approach is to capture all passes in the same order they'll run:
# fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
if _order is not None: # pylint: disable=too-many-nested-blocks
per_callable_static_outputs = [None] * len(flatten_sample_args)
per_callable_output_unflatten_spec = [None] * len(flatten_sample_args)
per_callable_static_grad_outputs = [None] * len(flatten_sample_args)
per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
m_chunk = c_id-1
for l_no in range(num_layers):
func = callables[m_chunk*num_layers + l_no]
per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) \
+ (fwd_idx[m_chunk] * num_layers + l_no)
args = sample_args[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
graph_callables[per_callable_fwd_idx] = func
fwd_idx[m_chunk] += 1
else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id-1
for l_no in list(reversed(range(num_layers))):
per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) \
+ (bwd_idx[m_chunk] * num_layers + l_no)
static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None) # type: ignore[arg-type]
static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
bwd_idx[m_chunk] += 1
else:
# Capture forward graphs
per_callable_static_outputs = []
per_callable_output_unflatten_spec = []
graph_id = 0
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
graph_callables[graph_id] = func
graph_id += 1
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs.append(tuple(flatten_outputs))
per_callable_output_unflatten_spec.append(spec)
# Capture backward graphs in reverse order
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
for static_input_surface, static_outputs, bwd_graph in zip(
reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs),
reversed(bwd_graphs),
):
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs that
# don't require grad. I couldn't think of a slick one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None) # type: ignore[arg-type]
static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
per_callable_static_grad_outputs.append(static_grad_outputs)
per_callable_static_grad_inputs.append(static_grad_inputs)
# Reverses the most recent two lists
per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
# Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
def make_graphed_autograd_function(
fwd_graph,
bwd_graph,
module_params,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
static_grad_outputs,
static_grad_inputs,
):
class Graphed(torch.autograd.Function):
"""Autograd function for graph replay."""
@staticmethod
def forward(ctx, skip_fp8_weight_update, *inputs):
# At this stage, only the user args may (potentially) be new tensors.
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
if ctx.is_first_module and skip_fp8_weight_update is not None:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
static_input_surface[i].copy_(inputs[i])
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *grads):
assert len(grads) == len(static_grad_outputs)
for g, grad in zip(static_grad_outputs, grads):
if g is not None:
# don't copy if autograd gods have been kind and the
# incoming grad is already in the right place
if g.data_ptr() != grad.data_ptr():
g.copy_(grad)
bwd_graph.replay()
if ctx.is_first_module:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
return (None,) + tuple(
b.detach() if b is not None else b for b in static_grad_inputs
)
def functionalized(*user_args, **user_kwargs):
# Runs the autograd function with inputs == all
# inputs to the graph that might require grad
# (explicit user args + module parameters)
# Assumes module params didn't change since capture.
skip_fp8_weight_update = None
if fp8_weight_caching:
assert (
("is_first_microbatch" in user_kwargs
and isinstance(user_kwargs["is_first_microbatch"], bool))
), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching."
skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]
flatten_user_args, _ = _tree_flatten(user_args)
out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params))
return _tree_unflatten(out, output_unflatten_spec)
return functionalized
# Put together the final graphed callables
ret = []
for i in range(len(sample_args)):
graphed = make_graphed_autograd_function(
fwd_graphs[i],
bwd_graphs[i],
per_callable_module_params[i],
per_callable_len_user_args[i],
per_callable_output_unflatten_spec[i],
per_callable_static_input_surfaces[i],
per_callable_static_outputs[i],
per_callable_static_grad_outputs[i],
per_callable_static_grad_inputs[i],
)
func = graph_callables[i]
if isinstance(func, torch.nn.Module):
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
def new_fwd(*user_args, **user_kwargs):
# If the module's training-or-eval state matches what we graphed,
# run the graph, otherwise run the original forward method
if func.training == graph_training_state:
# Set the FP8 group from global amax reduction.
for m in func.modules():
if (isinstance(m, TransformerEngineBaseModule)
and FP8GlobalStateManager.is_fp8_enabled()):
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m.fp8_meta, fp8_weights=m._get_fp8_params())
return graphed(*user_args, **user_kwargs)
return orig_fwd(*user_args, **user_kwargs)
return new_fwd
forward = make_graphed_forward(func, func.training, graphed, func.forward)
if _order is None:
func.forward = forward
ret.append(func)
else:
ret.append(forward)
else:
ret.append(graphed)
if just_one_callable:
return ret[0]
return tuple(ret)
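# Illustrative sketch (not part of this patch): the capture ordering used above for a
# shared memory pool -- forward graphs in execution order, then backward graphs in
# reverse order. Modules, shapes and warmup count are hypothetical.
import torch

mods = [torch.nn.Linear(32, 32).cuda() for _ in range(2)]
xs = [torch.randn(4, 32, device="cuda", requires_grad=True) for _ in range(2)]

# Warmup on a side stream so lazy initialization does not end up inside any capture.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for m, x in zip(mods, xs):
        m(x).sum().backward()
torch.cuda.current_stream().wait_stream(side)
torch.cuda.synchronize()

pool = torch.cuda.graph_pool_handle()
fwd_graphs = [torch.cuda.CUDAGraph() for _ in mods]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in mods]
outs = []

# Capture forward passes: fwd 1, fwd 2.
for m, x, g in zip(mods, xs, fwd_graphs):
    with torch.cuda.graph(g, pool=pool):
        outs.append(m(x))

# Capture backward passes in reverse: bwd 2, bwd 1.
for m, x, out, g in zip(reversed(mods), reversed(xs), reversed(outs), reversed(bwd_graphs)):
    grad_out = torch.ones_like(out)
    with torch.cuda.graph(g, pool=pool):
        torch.autograd.grad(out, (x, *m.parameters()), grad_outputs=grad_out,
                            only_inputs=True)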
def save_fp8_tensors(modules, amax_history_len):
"""
Returns the FP8 tensors for all modules
with adjusted amax history sizes.
"""
saved_fp8_meta_tensors = []
for module in modules:
for m in module.modules():
if isinstance(m, TransformerEngineBaseModule):
if m.primary_weights_in_fp8:
m.adjust_amax_history_length(amax_history_len)
saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors())
return saved_fp8_meta_tensors
def restore_fp8_tensors(modules, fp8_tensors):
"""Restore FP8 tensors."""
for module in modules:
for m in module.modules():
if isinstance(m, TransformerEngineBaseModule):
m.reset_fp8_meta_tensors(fp8_tensors.pop(0))
assert len(fp8_tensors) == 0, "TE internal error."
def make_graphed_callables(
modules,
sample_args,
num_warmup_iters=3,
allow_unused_input=False,
fp8_enabled=False,
fp8_calibrating=False,
fp8_recipe=None,
fp8_weight_caching=False,
_order=None,
):
"""
A version of PyTorch's `make_graphed_callables` utility function with support for
TransformerEngine modules and FP8. Please see the original version in upstream PyTorch
`here <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
for extensive documentation. The documentation for the additional FP8-specific
parameters is given below.
FP8 specific parameters
-----------------------
fp8_enabled: bool, default = `False`
whether or not to enable fp8
fp8_calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
data of fp8 tensors even when executing without fp8 enabled. This is
useful for saving an inference-ready fp8 checkpoint while training
using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
fp8_weight_caching: bool, default = `False`
Whether or not to cache FP8 weights across microbatches. If set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward
method for TransformerEngine modules. When storing primary weights in FP8
using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg
must be set to `False` if calculating weight transposes outside TE, e.g.,
in the optimizer step.
"""
set_capture_start()
fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
# Handle single module.
just_one_callable = False
if not isinstance(modules, tuple):
just_one_callable = True
modules = (modules,)
# Store FP8 tensors to reset later.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len)
# FP8 wrapper.
def wrap_autocast(block):
old_forward = block.forward
def forward_func(*args, **kwargs):
with fp8_autocast(enabled=fp8_enabled,
calibrating=fp8_calibrating,
fp8_recipe=fp8_recipe,
_graph=True):
outputs = old_forward(*args, **kwargs)
return outputs
block.forward = forward_func
forward_funcs = []
for module in modules:
assert isinstance(module, torch.nn.Module), f"Graphing for {type(module)} is not supported."
wrap_autocast(module)
forward_funcs.append(module)
if just_one_callable:
forward_funcs = forward_funcs[0]
else:
forward_funcs = tuple(forward_funcs)
# Save RNG state.
if graph_safe_rng_available():
generators = [torch.cuda.default_generators[torch.cuda.current_device()],
*get_all_rng_states().values()]
original_rng_states = [state.get_state() for state in generators]
else:
original_rng_states = torch.cuda.get_rng_state()
graphed_callables = _make_graphed_callables(
forward_funcs, sample_args, num_warmup_iters=num_warmup_iters,
allow_unused_input=allow_unused_input,
fp8_weight_caching=fp8_weight_caching, _order=_order)
# Ensures warmup does not affect numerics for ops such as dropout.
if graph_safe_rng_available():
for gen, state in zip(generators, original_rng_states):
gen.set_state(state)
else:
torch.cuda.set_rng_state(original_rng_states)
# Reset FP8 gradients.
for module in modules:
for p in module.parameters():
p.grad = None
# Restore FP8 state.
restore_fp8_tensors(modules, saved_fp8_tensors)
set_capture_end()
return graphed_callables
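# Illustrative sketch (not part of this patch): graphing a single TransformerEngine
# module with FP8 capture enabled. Assumes an FP8-capable GPU and that
# make_graphed_callables is exported from transformer_engine.pytorch; the module type
# and shapes below are hypothetical.
import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024, device="cuda")
sample_input = torch.randn(16, 1024, device="cuda", requires_grad=True)
graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    num_warmup_iters=3,
    fp8_enabled=True,  # forward is wrapped in fp8_autocast during capture
)
out = graphed_model(torch.randn(16, 1024, device="cuda", requires_grad=True))
out.sum().backward()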
......@@ -8,8 +8,7 @@ import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Generator, Union, Optional, Tuple, Dict, Any, List
from functools import partial
from typing import Generator, Union, Optional, Tuple, List
from contextlib import contextmanager
import torch
......@@ -22,13 +21,11 @@ from ..fp8 import (
get_default_fp8_recipe,
get_fp8_te_dtype,
FP8GlobalStateManager,
amax_and_scale_update,
)
from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
get_distributed_world_size,
)
from ..cpp_extensions import (
fp8_cast_transpose_fused,
......@@ -44,7 +41,6 @@ _2X_ACC_WGRAD = True
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None
layers_atomic_ring_exchange = []
......@@ -64,49 +60,6 @@ def get_workspace() -> torch.Tensor:
)
return _cublas_workspace
@contextmanager
def _prepare_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
name: str = ""
) -> Generator[None, None, None]:
"""Checks and prep for BWD."""
if fp8:
global _amax_reduce_handle_bwd
if _amax_reduce_handle_bwd is not None:
_amax_reduce_handle_bwd.wait()
_amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction
if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
# From previous iteration
FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
else:
amax_and_scale_update(fp8_meta, False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if (fp8 and fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(fp8_meta["fp8_group"]) > 1):
if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
fp8_meta,
tp_group,
tp_size,
forward=False
)
FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False)
def initialize_ub(
shape: list,
......@@ -300,31 +253,54 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_size = 1
self.sequence_parallel = False
self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
)
self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`.
.. warning::
This changes the underlying amax memory location.
"""
if fwd is None:
fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
else:
fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)
for meta_key in fp8_meta_tensor_keys:
curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
if length == curr_len:
continue
if length < curr_len:
self.fp8_meta[meta_key].amax_history = (
self.fp8_meta[meta_key].amax_history[: length].clone())
elif length > curr_len:
extra_rows = length - curr_len
self.fp8_meta[meta_key].amax_history = F.pad(
self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
)
# Update the global buffers with new amax and history pointers.
if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
fwd_pos, fwd_key, bwd_pos, bwd_key = (
self.fp8_meta[FP8GlobalStateManager.get_buffer_info()])
for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
if buffer_key in FP8GlobalStateManager.global_amax_buffer:
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
self.fp8_meta[meta_key].amax_history[0])
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
self.fp8_meta[meta_key].amax_history)
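# Illustrative sketch (not part of this patch): growing an amax history the same way as
# above, by zero-padding extra rows at the end. The sizes are hypothetical.
import torch
import torch.nn.functional as F

amax_history = torch.zeros(4, 6)                          # (history_len, num_fp8_tensors)
amax_history = F.pad(amax_history, pad=(0, 0, 0, 16 - 4))  # append 12 zero rows
assert amax_history.shape == (16, 6)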
def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
if self.fp8_meta_tensors_initialized:
# Handle changed amax history size.
curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0]
need_len = self.fp8_meta["recipe"].amax_history_len
if need_len < curr_len:
self.fp8_meta[fp8_meta_tensor_key].amax_history = (
self.fp8_meta[fp8_meta_tensor_key]
.amax_history[: self.fp8_meta["recipe"].amax_history_len].clone()
)
elif need_len > curr_len:
extra_rows = need_len - curr_len
self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad(
self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows)
)
self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd)
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
......@@ -347,25 +323,45 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
device="cuda",
)
# Needed for calculation of scale inverses to
# preserve scale_inv when caching FP8 weights
if fwd:
# [True, False, True]: -> [input, weight, output]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, False, True] * self.fp8_meta["num_gemms"]
).cuda()
else:
# [True, True]: -> [grad_output, grad_input]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, True] * self.fp8_meta["num_gemms"]
).cuda()
def init_fp8_meta_tensors(self) -> None:
"""Init scales and amaxes."""
self.set_meta_tensor(True)
self.set_meta_tensor(False)
self.fp8_meta_tensors_initialized = True
def get_fp8_meta_tensors(self) -> None:
"""Get scales and amaxes."""
fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
return None
fp8_meta_tensors = {fwd_key: [], bwd_key: []}
with torch.no_grad():
for key in (fwd_key, bwd_key):
fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone())
fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone())
fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone())
return fp8_meta_tensors
def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
"""Reset scales and amaxes."""
def reset(key):
if key in self.fp8_meta:
if fp8_meta_tensors is None:
self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale))
self.fp8_meta[key].scale_inv.copy_(
torch.ones_like(self.fp8_meta[key].scale_inv))
self.fp8_meta[key].amax_history.copy_(
torch.zeros_like(self.fp8_meta[key].amax_history))
else:
assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0])
self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1])
self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2])
with torch.no_grad():
reset("scaling_fwd")
reset("scaling_bwd")
def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
state = None
......@@ -380,13 +376,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint()
# Store other picklable values.
extra = {}
for k, v in self.fp8_meta.items():
if isinstance(v, (bool, int, float, str, list)):
if isinstance(v, (bool, int, float, str, tuple, list)):
extra[k] = v
state["extra_fp8_variables"] = extra
......@@ -414,11 +408,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None:
return
# Restore global FP8 amax buffer.
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
# Restore global FP8 state.
FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
......@@ -527,6 +516,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_group = tp_group
self.tp_group_initialized = True
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
"""returns the FP8 weights."""
fp8_params = []
for param in self.parameters():
if isinstance(param, Float8Tensor) and param.requires_grad:
fp8_params.append(param)
if len(fp8_params) == 0:
return None
return fp8_params
# This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
......@@ -576,7 +575,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
......@@ -594,49 +592,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if is_first_microbatch is not None and not self.primary_weights_in_fp8:
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \
"necessary when using sequence parallelism with FP8."
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.new_fp8_context_id())
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.get_fp8_context_id())
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.fp8_meta, fp8_weights=self._get_fp8_params())
# Activation recomputation is used and this is the first forward phase.
if (
......@@ -653,18 +616,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return
if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
FP8GlobalStateManager.global_amax_reduction,
self.fp8_meta,
self.tp_group,
self.tp_size,
forward=True
)
FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func)
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the
......
......@@ -14,7 +14,6 @@ from .. import cpp_extensions as tex
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
......@@ -65,6 +64,7 @@ class _LayerNormLinear(torch.autograd.Function):
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
......@@ -89,6 +89,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool,
ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -98,7 +99,11 @@ class _LayerNormLinear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
......@@ -196,7 +201,6 @@ class _LayerNormLinear(torch.autograd.Function):
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
......@@ -214,6 +218,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
tex.cast_to_fp8(
......@@ -295,6 +300,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_t_fp8,
ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
)
ctx.activation_dtype = activation_dtype
......@@ -321,6 +327,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
......@@ -344,9 +351,7 @@ class _LayerNormLinear(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
):
with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
(
inputmat,
ln_weight,
......@@ -357,6 +362,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_t_fp8,
ln_out,
fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
......@@ -364,10 +370,13 @@ class _LayerNormLinear(torch.autograd.Function):
weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
if ctx.ub_overlap_rs_dgrad:
ctx.ub_bulk_dgrad = False
......@@ -472,7 +481,7 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm(
weight_t_fp8._data,
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -686,6 +695,8 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -970,7 +981,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta'))
......@@ -990,6 +1000,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
......@@ -1084,6 +1098,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
......@@ -1132,6 +1150,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
skip_fp8_weight_update,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
......@@ -1156,6 +1175,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad,
self.ub_overlap_ag,
self.ub_name,
self.dummy_tensor,
)
out = fwd_fn(*args)
......
......@@ -13,7 +13,6 @@ from torch.nn import init
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
......@@ -94,6 +93,7 @@ class _LayerNormMLP(torch.autograd.Function):
use_fc2_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
......@@ -121,6 +121,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_overlap_rs: bool,
ub_overlap_ag: bool,
gemm_gelu_fusion: bool,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -131,7 +132,11 @@ class _LayerNormMLP(torch.autograd.Function):
assert_dim_for_fp8_exec(fc1_weight)
assert_dim_for_fp8_exec(fc2_weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
activation_func = _act_func(activation)[0]
......@@ -225,8 +230,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.reset_fp8_meta_scale_inv()
fc1_weight_fp8 = fc1_weight
fc2_weight_fp8 = fc2_weight
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
elif update_fp8_weights:
# Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor(
......@@ -250,6 +253,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
cast_out=fc1_weight_fp8._data,
transpose_out=fc1_weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
tex.fp8_cast_transpose_fused(
fc2_weight,
......@@ -258,6 +262,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
cast_out=fc2_weight_fp8._data,
transpose_out=fc2_weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
tex.cast_to_fp8(
......@@ -510,6 +515,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_t_fp8,
fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
)
ctx.activation_dtype = activation_dtype
ctx.activation = activation
......@@ -538,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_overlap_ag = ub_overlap_ag
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear
if ub_overlap_rs:
......@@ -563,9 +570,7 @@ class _LayerNormMLP(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
):
with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
(
inputmat,
ln_weight,
......@@ -582,6 +587,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_t_fp8,
fc1_bias,
fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
......@@ -592,11 +598,18 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.main_grad = fc2_weight_main_grad
# Primary weights are in FP8.
update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy"
if ctx.fp8 and fc1_weight_t_fp8 is None:
fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache)
if ctx.fp8 and fc2_weight_t_fp8 is None:
fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache)
if ctx.primary_weights_in_fp8:
fc1_weight_t_fp8 = fc1_weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
fc2_weight_t_fp8 = fc2_weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
fc1_weight_t_fp8 = fc1_weight_t_fp8._data
fc2_weight_t_fp8 = fc2_weight_t_fp8._data
activation_func = _act_func(ctx.activation)[1]
......@@ -673,7 +686,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8._data,
fc2_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
......@@ -826,7 +839,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj = None
# FC1 DGRAD: Unconditional
_ = tex.fp8_gemm(
fc1_weight_t_fp8._data,
fc1_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -1151,6 +1164,8 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -1389,7 +1404,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=2)
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta'))
......@@ -1414,6 +1428,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
......@@ -1473,7 +1491,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
@no_torch_dynamo()
def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a feedforward network (MLP Block).
......@@ -1497,6 +1517,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
......@@ -1535,6 +1559,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
skip_fp8_weight_update,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
......@@ -1562,6 +1587,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs,
self.ub_overlap_ag,
self.gemm_gelu_fusion,
self.dummy_tensor,
)
out = fwd_fn(*args)
......
......@@ -11,7 +11,6 @@ import transformer_engine_extensions as tex
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
......@@ -65,6 +64,7 @@ class _Linear(torch.autograd.Function):
bias: torch.Tensor,
use_bias: bool,
is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
......@@ -80,7 +80,8 @@ class _Linear(torch.autograd.Function):
primary_weights_in_fp8: bool,
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_name: str
ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
......@@ -90,7 +91,12 @@ class _Linear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
......@@ -140,7 +146,6 @@ class _Linear(torch.autograd.Function):
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
......@@ -158,6 +163,7 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
cast_to_fp8(
......@@ -296,6 +302,7 @@ class _Linear(torch.autograd.Function):
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
......@@ -313,6 +320,7 @@ class _Linear(torch.autograd.Function):
ctx.ub_name = ub_name
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear
if ub_overlap_rs:
......@@ -330,9 +338,7 @@ class _Linear(torch.autograd.Function):
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
):
with torch.cuda.nvtx.range("_Linear_backward"):
(
inputmat,
inputmat_t,
......@@ -340,6 +346,7 @@ class _Linear(torch.autograd.Function):
main_grad,
weight_t_fp8,
fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
......@@ -347,10 +354,14 @@ class _Linear(torch.autograd.Function):
weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
if ctx.ub_overlap_ag:
......@@ -361,6 +372,7 @@ class _Linear(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
(
grad_output,
grad_output_c,
......@@ -401,7 +413,7 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
if ctx.fp8:
dgrad, _ = fp8_gemm(
weight_t_fp8._data,
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -542,6 +554,8 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -772,7 +786,6 @@ class Linear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta'))
......@@ -785,6 +798,10 @@ class Linear(TransformerEngineBaseModule):
else:
self.gemm_bias_unfused_add = False
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
......@@ -858,6 +875,10 @@ class Linear(TransformerEngineBaseModule):
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
......@@ -903,6 +924,7 @@ class Linear(TransformerEngineBaseModule):
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
skip_fp8_weight_update,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
......@@ -919,6 +941,7 @@ class Linear(TransformerEngineBaseModule):
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_name,
self.dummy_tensor,
)
out = linear_fn(*args)
......
......@@ -473,6 +473,15 @@ class TransformerLayer(torch.nn.Module):
if hasattr(child, "set_tensor_parallel_group"):
child.set_tensor_parallel_group(tp_group)
def reset_fp8_meta_tensors(self) -> None:
"""Set TP group"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
continue
if hasattr(child, "reset_fp8_meta_tensors"):
child.reset_fp8_meta_tensors()
def set_context_parallel_group(
self,
cp_group: Union[dist_group_type, None],
......@@ -665,7 +674,8 @@ class TransformerLayer(torch.nn.Module):
# MLP.
mlp_outputs = self.layernorm_mlp(
hidden_states, is_first_microbatch=is_first_microbatch
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs
......