Unverified commit 73f8d90f authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] cuda graph support (#575)



* FP8 cuda graphs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>

* Fix numerics
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* exclude torch compile from numerics tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More numerics fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm fusion from unfused path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
parent 1b20f2d6
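For context, CUDA graph capture in PyTorch records a sequence of kernels on a dedicated side stream against fixed ("static") input/output buffers and later replays them with minimal launch overhead. The changes in this commit (current-stream queries, graph-safe RNG handling, noop-flag kernels, bucketed amax reduction) are what make the FP8 modules below survive such a capture. A minimal, illustrative capture/replay sketch follows; it uses plain torch.cuda.CUDAGraph around a Transformer Engine module without FP8 autocast and does not show the FP8-specific plumbing this PR adds:

import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024).cuda()                 # any graph-capturable TE module
static_in = torch.randn(16, 1024, device="cuda")     # capture requires static buffers

# Warm up on a side stream before capture (standard CUDA-graph practice).
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        static_out = model(static_in)
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)                    # recorded, not executed

static_in.copy_(torch.randn_like(static_in))         # refill the static input...
graph.replay()                                       # ...and re-run the captured forward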
......@@ -157,7 +157,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
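The getDefaultCUDAStream -> getCurrentCUDAStream change in these hunks matters because graph capture runs on a dedicated capture stream rather than the default stream; synchronizing the communication streams against the default stream would therefore miss the captured work. A small sketch of the distinction, assuming a CUDA-capable build:

import torch

x = torch.ones(4, device="cuda")
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Inside capture the current stream is the capture stream, not the default one,
    # which is why the overlap code must query the *current* stream.
    assert torch.cuda.current_stream() != torch.cuda.default_stream()
    y = x * 2
graph.replay()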
......@@ -238,13 +238,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _stop_comm, 0));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
......@@ -350,11 +347,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
......@@ -469,13 +467,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
_ub_comm->sms = ori_sms;
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
_ub_comm->sms = ori_sms;
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main);
......@@ -506,7 +504,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
}
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(),
......@@ -805,14 +803,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
if (_aggregate2) {
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
if (_aggregate2) {
const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
......@@ -877,21 +876,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
}
at::cuda::setCurrentCUDAStream(stream_main);
int last_compute_stream_id =
(num_steps + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
} else {
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
......@@ -936,16 +923,19 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
}
at::cuda::setCurrentCUDAStream(stream_main);
int last_compute_stream_id = (_tp_size + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
}
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return D;
} // split_overlap_ag
......
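The split-overlap restructuring above replaces the single "last compute stream" bookkeeping with a plain fork/join pattern: every helper stream waits on one start event recorded on the main stream, and the main stream later waits on a stop event recorded on every helper. This shape has no data-dependent stream indices, which keeps the ordering capture-friendly. A rough PyTorch-level analogue of the pattern (not the userbuffers code itself):

import torch

a = torch.randn(256, 256, device="cuda")
b = torch.randn(256, 256, device="cuda")

main = torch.cuda.current_stream()
workers = [torch.cuda.Stream() for _ in range(3)]

start = torch.cuda.Event()
start.record(main)                  # fork: helpers wait for work already queued on main
results = []
for s in workers:
    s.wait_event(start)
    with torch.cuda.stream(s):
        results.append(a @ b)       # stand-in for a GEMM/communication chunk

for s in workers:                   # join: main waits for every helper to finish
    stop = torch.cuda.Event()
    stop.record(s)
    main.wait_event(stop)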
......@@ -43,6 +43,7 @@
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <transformer_engine/cast_transpose_noop.h>
namespace transformer_engine {
......
......@@ -223,6 +223,17 @@ void fused_cast_transpose(at::Tensor input,
);
void fused_cast_transpose_noop(at::Tensor input,
at::Tensor noop,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
);
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
......@@ -263,6 +274,17 @@ at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype
);
void fp8_transpose_noalloc(at::Tensor input,
at::Tensor output,
transformer_engine::DType otype
);
void fp8_transpose_noalloc_noop(at::Tensor input,
at::Tensor output,
at::Tensor noop,
transformer_engine::DType otype
);
/***************************************************************************************************
* Activations
**************************************************************************************************/
......@@ -559,14 +581,11 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
* FP8 recipe
**************************************************************************************************/
void fused_amax_and_scale_update(const at::Tensor &amax_history,
const at::Tensor &scale,
const at::Tensor &scale_inv,
const at::Tensor &scale_inv_mask,
at::Tensor updated_amax_history,
at::Tensor updated_scale,
at::Tensor updated_scale_inv,
const std::string& amax_compute_algo,
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
std::vector<at::Tensor> scale_invs,
const std::string &amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin);
......
......@@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop,
"Fused Cast + Transpose with noop option");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad,
......@@ -67,6 +69,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O");
m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop,
"Transpose with FP8 I/O with noop option.");
m.def("gelu", &gelu, "GeLU with FP8 output");
m.def("relu", &relu, "ReLU with FP8 output");
m.def("geglu", &geglu, "GeGLU with FP8 output");
......@@ -82,9 +87,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
m.def("fused_amax_and_scale_update",
&fused_amax_and_scale_update,
"Update amax history and FP8 scale");
m.def("fused_amax_and_scale_update_after_reduction",
&fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction");
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD");
......
......@@ -11,24 +11,50 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
void fused_amax_and_scale_update(const at::Tensor &amax_history,
const at::Tensor &scale,
const at::Tensor &scale_inv,
const at::Tensor &scale_inv_mask,
at::Tensor updated_amax_history,
at::Tensor updated_scale,
at::Tensor updated_scale_inv,
const std::string& amax_compute_algo,
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
std::vector<at::Tensor> scale_invs,
const std::string &amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin) {
nvte_delayed_scaling_recipe_amax_and_scale_update(
makeTransformerEngineTensor(amax_history).data(),
makeTransformerEngineTensor(scale).data(),
makeTransformerEngineTensor(scale_inv).data(),
makeTransformerEngineTensor(scale_inv_mask).data(),
makeTransformerEngineTensor(updated_amax_history).data(),
makeTransformerEngineTensor(updated_scale).data(),
makeTransformerEngineTensor(updated_scale_inv).data(),
using namespace transformer_engine;
size_t num_tensors = amax_histories.size();
std::vector<Tensor> t_amax_histories(num_tensors);
std::vector<Tensor> t_scales(num_tensors);
std::vector<Tensor> t_scale_invs(num_tensors);
std::vector<NVTETensor> te_amax_histories(num_tensors);
std::vector<NVTETensor> te_scales(num_tensors);
std::vector<NVTETensor> te_scale_invs(num_tensors);
for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
auto amax_sizes = amax_histories[i].sizes().vec();
std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()};
t_amax_histories[i].data.shape = amax_shape;
t_amax_histories[i].data.dtype = DType::kFloat32;
t_scales[i].data.dptr = scales[i].data_ptr();
auto scale_sizes = scales[i].sizes().vec();
std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()};
t_scales[i].data.shape = scale_shape;
t_scales[i].data.dtype = DType::kFloat32;
t_scale_invs[i].data.dptr = scale_invs[i].data_ptr();
auto scale_inv_sizes = scale_invs[i].sizes().vec();
std::vector<size_t> scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()};
t_scale_invs[i].data.shape = scale_inv_shape;
t_scale_invs[i].data.dtype = DType::kFloat32;
te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
te_scale_invs[i] = reinterpret_cast<NVTETensor>(&t_scale_invs[i]);
}
nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
makeTransformerEngineTensor(amax_reduction_buffer).data(),
te_amax_histories,
te_scales,
te_scale_invs,
amax_compute_algo.c_str(),
static_cast<NVTEDType>(fp8_dtype),
margin,
......
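For reference, the delayed-scaling update this kernel applies in bulk (after the per-module amaxes have been all-reduced into a single buffer) is commonly described as: pick an amax from the history according to amax_compute_algo ("max" over the window or the most recent entry), then set scale = fp8_max / amax / 2^margin (fp8_max being the format's largest representable value, e.g. 448 for E4M3) and scale_inv = 1 / scale. A rough Python restatement under those assumptions; the fused kernel's exact handling of zeros, non-finite values and history rolling may differ:

import torch

def delayed_scaling_update(amax_history: torch.Tensor, margin: float,
                           fp8_max: float, amax_compute_algo: str = "max"):
    """amax_history: (history_len, num_tensors) float32. Returns (scale, scale_inv)."""
    if amax_compute_algo == "max":
        amax = amax_history.amax(dim=0)
    else:  # "most_recent": row 0 is assumed to hold the latest amaxes
        amax = amax_history[0]
    scale = (fp8_max / amax) / (2.0 ** margin)
    scale = torch.where(torch.isfinite(scale), scale, torch.ones_like(scale))
    return scale, 1.0 / scale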
......@@ -32,6 +32,35 @@ void fused_cast_transpose(at::Tensor input,
}
void fused_cast_transpose_noop(at::Tensor input,
at::Tensor noop,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input);
auto noop_cu = makeTransformerEngineTensor(noop);
auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(),
output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
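The new binding takes the same arguments as fused_cast_transpose plus a one-element float32 noop tensor that the kernel reads on the GPU at launch time, so a captured graph can decide at replay whether the cast+transpose does any work. A hedged Python usage sketch; the DType enum spelling and the one-element scale/amax/scale_inv layout are assumptions:

import torch
import transformer_engine_extensions as tex

m, n = 128, 64
inp        = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
noop       = torch.zeros(1, dtype=torch.float32, device="cuda")  # flag read by the kernel at launch
scale      = torch.ones(1, dtype=torch.float32, device="cuda")
amax       = torch.zeros(1, dtype=torch.float32, device="cuda")
scale_inv  = torch.ones(1, dtype=torch.float32, device="cuda")
casted     = torch.empty(m, n, dtype=torch.uint8, device="cuda")
transposed = torch.empty(n, m, dtype=torch.uint8, device="cuda")

# When the flag enables execution, this writes the FP8 cast into `casted` and its
# transpose into `transposed` in a single fused pass.
tex.fused_cast_transpose_noop(inp, noop, scale, amax, scale_inv,
                              casted, transposed, tex.DType.kFloat8E4M3)  # enum spelling assumed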
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
......@@ -319,3 +348,39 @@ at::Tensor fp8_transpose(at::Tensor input,
return output;
}
void fp8_transpose_noalloc(at::Tensor input,
at::Tensor output,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
}
void fp8_transpose_noalloc_noop(at::Tensor input,
at::Tensor output,
at::Tensor noop,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
auto noop_cu = makeTransformerEngineTensor(noop);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
nvte_transpose_with_noop(
input_cu.data(), noop_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -5,10 +5,10 @@
"""Methods needed for distributed training (DP/TP)."""
import warnings
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from typing import Any, Dict, Union, Optional, Callable, Tuple
from typing import Any, Dict, List, Union, Optional, Callable, Tuple
import torch
from torch.cuda import _lazy_call
from torch.cuda import _lazy_call, _lazy_init
from torch.utils.checkpoint import detach_variable, noop_context_fn
from .utils import safely_set_viewless_tensor_data
......@@ -31,15 +31,60 @@ _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False
def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None:
"""Sets the random number generator state of the current GPU.
_ALL_ACTIVE_RNG_STATES = {}
def get_all_rng_states() -> Dict:
"""Returns all generator states used by `CudaRNGStatesTracker`."""
return _ALL_ACTIVE_RNG_STATES
def set_all_rng_states(states: Dict) -> None:
"""Updates all generator states used by `CudaRNGStatesTracker`."""
global _ALL_ACTIVE_RNG_STATES
_ALL_ACTIVE_RNG_STATES = states
def graph_safe_rng_available() -> bool:
"""Returns whether cuda graph safe RNG state manipulation is supported."""
return (hasattr(torch.cuda.CUDAGraph, "register_generator_state")
and hasattr(torch.Generator, "graphsafe_set_state")
and hasattr(torch.Generator, "graphsafe_get_state")
and hasattr(torch.Generator, "clone_state"))
def _get_cuda_rng_state(
device: Union[int, str, torch.device] = "cuda",
clone: bool = False,
graph_safe: bool = True,
) -> torch.Tensor:
"""Return the random number generator state of the specified GPU."""
_lazy_init()
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
if graph_safe_rng_available() and graph_safe:
if clone:
# Reference to the cloned generator state
return default_generator.clone_state()
# Reference to the current generator state
return default_generator.graphsafe_get_state()
return default_generator.get_state()
def _set_cuda_rng_state(
new_state: torch.Tensor,
device: Union[int, str] = -1,
graph_safe: bool = True,
) -> None:
"""Sets the random number generator state of the current GPU."""
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
if device == -1:
device = torch.device("cuda")
elif isinstance(device, str):
......@@ -52,6 +97,9 @@ def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
if graph_safe_rng_available() and graph_safe:
default_generator.graphsafe_set_state(new_state)
return
default_generator.set_state(new_state)
_lazy_call(cb)
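Taken together, the helpers above let RNG state be handled through generator-state objects that a CUDA graph can register, instead of through byte-tensor snapshots that a captured graph never sees. A minimal sketch of how they compose, assuming a PyTorch build where graph_safe_rng_available() returns True (it only uses the functions defined above):

if graph_safe_rng_available():
    state = _get_cuda_rng_state(clone=True)   # independent clone of the default generator state
    state.manual_seed(1234)                   # seed the clone without touching the live state
    _set_cuda_rng_state(state)                # install it as the generator's graph-safe state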
......@@ -206,7 +254,7 @@ class _CheckpointFunction(torch.autograd.Function):
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None:
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
......@@ -271,13 +319,13 @@ class _CheckpointFunction(torch.autograd.Function):
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None:
bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
......@@ -291,7 +339,7 @@ class _CheckpointFunction(torch.autograd.Function):
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
......@@ -317,6 +365,7 @@ class _CheckpointFunction(torch.autograd.Function):
)
return (None, None, None, None, None, None) + grads
class _CheckpointFrame:
"""
Storage frame for forward RNG states and detached activations from the forward recompute.
......@@ -338,7 +387,7 @@ class _CheckpointFrame:
"""Cache fwd/bwd RNG states in the frame to restore later."""
rng_states = (
torch.get_rng_state(),
torch.cuda.get_rng_state(),
_get_cuda_rng_state(graph_safe=False),
)
if self.get_rng_state_tracker is not None:
rng_states += (self.get_rng_state_tracker().get_states(), )
......@@ -356,7 +405,7 @@ class _CheckpointFrame:
rng_states = self.bwd_rng_states
torch.set_rng_state(rng_states[0])
_set_cuda_rng_state(rng_states[1])
_set_cuda_rng_state(rng_states[1], graph_safe=False)
if self.get_rng_state_tracker is not None:
self.get_rng_state_tracker().set_states(rng_states[2])
......@@ -604,6 +653,7 @@ def checkpoint(
return out
class CudaRNGStatesTracker:
"""
For model parallelism, multiple RNG states need to simultaneously exist in order
......@@ -664,13 +714,23 @@ class CudaRNGStatesTracker:
# Check that state is not already defined.
if name in self.states_:
raise Exception(f"cuda rng state {name} already exists")
if graph_safe_rng_available():
new_state = _get_cuda_rng_state(clone=True)
new_state.manual_seed(seed)
self.states_[name] = new_state
# Update global states.
set_all_rng_states(self.states_)
else:
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
orig_rng_state = _get_cuda_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
self.states_[name] = _get_cuda_rng_state(clone=True)
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
# Update global states.
set_all_rng_states(self.states_)
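As a usage sketch (assuming the tracker's usual add(name, seed) signature shown above, with torch already imported in this module), a tracker seeds a named state once and fork() temporarily makes it the active CUDA RNG state:

tracker = CudaRNGStatesTracker()
tracker.add("model-parallel-rng", 1234)        # clones and seeds a dedicated generator state

with tracker.fork("model-parallel-rng"):
    mask = torch.rand(8, device="cuda")        # consumes the tracked state, not the default one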
@contextmanager
def fork(self, name: str = "model-parallel-rng"):
......@@ -684,16 +744,17 @@ class CudaRNGStatesTracker:
# Check if we have added the state
if name not in self.states_:
raise Exception(f"cuda rng state {name} is not added")
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Get the reference to current rng state.
orig_cuda_rng_state = _get_cuda_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# This is redundant with the graph-safe API.
if not graph_safe_rng_available():
self.states_[name] = _get_cuda_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
......
......@@ -16,6 +16,7 @@ from .fp8 import FP8GlobalStateManager
aten = torch.ops.aten
c10d = torch.ops.c10d
updated_fp8_params = {}
def _make_fp8_attr_property_funcs(name: str) -> Any:
......@@ -67,6 +68,31 @@ class _FromFloat8Func(torch.autograd.Function):
return grad, None
def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
"""Amax scale and update when there is at least 1 trainable FP8 parameter."""
param_id = id(param._data)
if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
return
autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]
if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
return
if autocast_key in updated_fp8_params:
updated_fp8_params[autocast_key].add(param_id)
else:
updated_fp8_params[autocast_key] = {param_id}
current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
# All FP8 trainable parameters have been updated.
if updated_fp8_params[autocast_key] == current_fp8_params_set:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(
forward=True, fp8_weights=True)
del updated_fp8_params[autocast_key]
class _ToFloat8Func(torch.autograd.Function):
"""Cast to FP8 from other dtype"""
@staticmethod
......@@ -167,6 +193,7 @@ class _ToFloat8Func(torch.autograd.Function):
# Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None
class _IdentityFunc(torch.autograd.Function):
"""Identity function
......@@ -307,8 +334,9 @@ class Float8Tensor(torch.Tensor):
), f"Unsupported fp8_dtype {fp8_dtype}."
self._fp8_dtype: tex.DType = fp8_dtype
# Cached transpose
# Transposed version of `_data`.
self._transpose: Optional[Float8Tensor] = None
self._transpose_invalid: bool = True
# FP8 scale-inverse
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
......@@ -435,80 +463,51 @@ class Float8Tensor(torch.Tensor):
return _IdentityFunc.apply(self)
return super().expand_as(other)
def transpose(
def transpose_2d(
self,
dim0: int = 0,
dim1: int = 1,
*,
update_cache: str | bool = "reuse_only",
cache: bool = False,
noop_flag: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Swap tensor dimensions
For basic 2D matrix transposes, an optimized transpose kernel
is applied and a Float8Tensor is returned.
2D transpose with caching support.
Parameters
----------
dim0: int, default = 0
The first dimension to be transposed
dim1: int, default = 1
The second dimension to be transposed
update_cache: str or bool, default = "reuse_only"
Memoization behavior. Options are
"reuse_only"/`False` (reuse cached value if
available, otherwise calculate transpose without
caching), "force"/`True` (calculate transpose
and cache), "lazy" (reuse cached value if
available, otherwise calculate transpose and
cache if possible). Caching is only supported
for basic 2D transposes and the cache is reset
after any in-place operations.
cache: bool, default = `False`
Whether or not to cache the transpose.
noop_flag: Optional[torch.Tensor], default = `None`
Only used if argument `cache` is `True`, ignored otherwise.
A single element fp32 tensor with a value of 1.0 or 0.0
which is treated as a boolean. `1.0` forces recompute
and `0.0` executes a noop using the same kernel.
"""
assert self.dim() == 2, f"{self.dim()}-D transpose not supported."
# Case: no caching.
if not cache:
return tex.fp8_transpose(self._data, self._fp8_dtype)
# Case: reuse cache without calling a kernel.
if not self._transpose_invalid and noop_flag is None:
assert self._transpose is not None, "Transpose cache is empty."
return self._transpose
# Allocate transpose if needed.
data_2d = self._data.reshape(-1, self._data.shape[-1])
if self._transpose is None:
shape = (data_2d.shape[1], data_2d.shape[0])
self._transpose = torch.empty(shape, dtype=torch.uint8, device=self._data.device)
# Case: recompute transpose and store cache.
if noop_flag is None:
tex.fp8_transpose_noalloc(data_2d, self._transpose, self._fp8_dtype)
else:
# Case: cuda graph capture.
tex.fp8_transpose_noalloc_noop(data_2d, self._transpose, noop_flag, self._fp8_dtype)
# Check caching mode
if not isinstance(update_cache, str):
update_cache = "force" if update_cache else "reuse_only"
if update_cache not in ("force", "reuse_only", "lazy"):
raise ValueError(
"Supported values for update_cache are "
'"force" (True), "reuse_only" (False), "lazy" '
f"(got {update_cache})"
)
# Handle non-2D transposes
if -self.dim() <= dim0 < 0:
dim0 += self.dim()
if -self.dim() <= dim1 < 0:
dim1 += self.dim()
if self.dim() != 2 or dim0 == dim1:
if update_cache == "force":
raise ValueError(
"Transpose caching is only supported for basic 2D transposes "
f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
)
return super().transpose(dim0, dim1)
# Clear cache if needed
if update_cache == "force":
self._transpose = None
# Compute transpose if needed
out = self._transpose
if out is None:
out = Float8Tensor.make_like(
self,
data=tex.fp8_transpose(
self._data.contiguous(),
self._fp8_dtype,
),
)
# Update cache if needed
if update_cache in ("force", "lazy"):
self._transpose = out
return out
self._transpose_invalid = False
return self._transpose
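A short usage sketch of the caching path (mirroring how the module backward passes below call it); fp8_weight here stands for any 2D Float8Tensor. With cache=True the transposed buffer is allocated once and reused, and under graph capture a noop_flag tensor can additionally be passed so the same captured kernel decides at replay time whether to refresh the cache (see the docstring above for the flag convention):

fp8_weight_t = fp8_weight.transpose_2d(cache=True)   # computes and caches the FP8 transpose
fp8_weight_t = fp8_weight.transpose_2d(cache=True)   # second call reuses the cached buffer
plain_t = fp8_weight.transpose_2d()                  # cache=False: one-off kernel, nothing stored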
@torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None:
......@@ -519,13 +518,11 @@ class Float8Tensor(torch.Tensor):
the tensor.
"""
if self._fp8_meta is None:
return
assert self._fp8_meta is not None, "FP8 meta tensors not found."
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward,
)
scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
scale_inv.view(1).copy_(self._scale_inv.view(1))
self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0])
def to_dtype(self, dtype: torch.dtype) -> Float8Tensor:
"""Create `Float8Tensor` with given nominal dtype
......@@ -541,12 +538,11 @@ class Float8Tensor(torch.Tensor):
)
def _reset_caches(self) -> None:
"""Reset cached values
"""
Set transpose cache as invalid.
Should be called after any in-place operation.
"""
self._transpose = None
self._transpose_invalid = True
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......@@ -574,7 +570,7 @@ class Float8Tensor(torch.Tensor):
# Directly copy FP8 data if possible
if dst._fp8_dtype == src._fp8_dtype:
dst._data.copy_(src._data)
dst._scale_inv = src._scale_inv.clone()
dst._scale_inv.copy_(src._scale_inv.detach())
if dst._fp8_meta is not None:
if src._fp8_meta is None:
src_min, src_max = src.from_float8().aminmax()
......@@ -600,7 +596,6 @@ class Float8Tensor(torch.Tensor):
dst.copy_(src.from_float8())
elif dst_is_fp8 and not src_is_fp8:
# Make sure input is in expected format
src = src.expand(dst.size())
src = src.to(
......@@ -619,7 +614,7 @@ class Float8Tensor(torch.Tensor):
fp8_meta_index = dst._fp8_meta_index
scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
dst._scale_inv = scale.detach().view(1).reciprocal()
dst._scale_inv.copy_(scale.detach().reciprocal())
# Cast to FP8
if not dst._data.is_contiguous():
......@@ -633,6 +628,9 @@ class Float8Tensor(torch.Tensor):
dst._fp8_dtype,
)
# This branch is where the FP8 parameters are updated in-place during optimization.
# Handle forward amax reduction.
post_optimizer_step_fwd_amax_reduction(dst)
else:
# Invalid case
......@@ -641,6 +639,7 @@ class Float8Tensor(torch.Tensor):
# Nothing to return for in-place ops
if dst_is_fp8:
dst._reset_caches()
return None
# Slice op
......@@ -764,6 +763,7 @@ class Float8Tensor(torch.Tensor):
_fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index"))
_fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype"))
_transpose = property(**_make_fp8_attr_property_funcs("transpose"))
_transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
# Do not force the Float8Tensor type on the returned tensor
......
......@@ -8,8 +8,7 @@ import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Generator, Union, Optional, Tuple, Dict, Any, List
from functools import partial
from typing import Generator, Union, Optional, Tuple, List
from contextlib import contextmanager
import torch
......@@ -22,13 +21,11 @@ from ..fp8 import (
get_default_fp8_recipe,
get_fp8_te_dtype,
FP8GlobalStateManager,
amax_and_scale_update,
)
from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
get_distributed_world_size,
)
from ..cpp_extensions import (
fp8_cast_transpose_fused,
......@@ -44,7 +41,6 @@ _2X_ACC_WGRAD = True
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None
layers_atomic_ring_exchange = []
......@@ -64,49 +60,6 @@ def get_workspace() -> torch.Tensor:
)
return _cublas_workspace
@contextmanager
def _prepare_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
name: str = ""
) -> Generator[None, None, None]:
"""Checks and prep for BWD."""
if fp8:
global _amax_reduce_handle_bwd
if _amax_reduce_handle_bwd is not None:
_amax_reduce_handle_bwd.wait()
_amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction
if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
# From previous iteration
FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
else:
amax_and_scale_update(fp8_meta, False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if (fp8 and fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(fp8_meta["fp8_group"]) > 1):
if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
fp8_meta,
tp_group,
tp_size,
forward=False
)
FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False)
def initialize_ub(
shape: list,
......@@ -300,31 +253,54 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_size = 1
self.sequence_parallel = False
self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
)
self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`.
.. warning::
This changes the underlying amax memory location.
"""
if fwd is None:
fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
else:
fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)
for meta_key in fp8_meta_tensor_keys:
curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
if length == curr_len:
continue
if length < curr_len:
self.fp8_meta[meta_key].amax_history = (
self.fp8_meta[meta_key].amax_history[: length].clone())
elif length > curr_len:
extra_rows = length - curr_len
self.fp8_meta[meta_key].amax_history = F.pad(
self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
)
# Update the global buffers with new amax and history pointers.
if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
fwd_pos, fwd_key, bwd_pos, bwd_key = (
self.fp8_meta[FP8GlobalStateManager.get_buffer_info()])
for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
if buffer_key in FP8GlobalStateManager.global_amax_buffer:
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
self.fp8_meta[meta_key].amax_history[0])
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
self.fp8_meta[meta_key].amax_history)
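adjust_amax_history_length grows the history with F.pad, whose pad tuple is ordered from the last dimension backwards, so pad=(0, 0, 0, extra_rows) appends zero rows along dim 0 (the history dimension) while leaving the per-tensor dimension untouched. A tiny check of that behaviour:

import torch
import torch.nn.functional as F

amax_history = torch.zeros(4, 6)               # (history_len, num_fp8_tensors)
grown = F.pad(amax_history, pad=(0, 0, 0, 3))  # last dim: (0, 0); first dim: (0, 3) new rows
assert grown.shape == (7, 6)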
def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
if self.fp8_meta_tensors_initialized:
# Handle changed amax history size.
curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0]
need_len = self.fp8_meta["recipe"].amax_history_len
if need_len < curr_len:
self.fp8_meta[fp8_meta_tensor_key].amax_history = (
self.fp8_meta[fp8_meta_tensor_key]
.amax_history[: self.fp8_meta["recipe"].amax_history_len].clone()
)
elif need_len > curr_len:
extra_rows = need_len - curr_len
self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad(
self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows)
)
self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd)
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
......@@ -347,25 +323,45 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
device="cuda",
)
# Needed for calculation of scale inverses to
# preserve scale_inv when caching FP8 weights
if fwd:
# [True, False, True]: -> [input, weight, output]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, False, True] * self.fp8_meta["num_gemms"]
).cuda()
else:
# [True, True]: -> [grad_output, grad_input]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, True] * self.fp8_meta["num_gemms"]
).cuda()
def init_fp8_meta_tensors(self) -> None:
"""Init scales and amaxes."""
self.set_meta_tensor(True)
self.set_meta_tensor(False)
self.fp8_meta_tensors_initialized = True
def get_fp8_meta_tensors(self) -> None:
"""Get scales and amaxes."""
fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
return None
fp8_meta_tensors = {fwd_key: [], bwd_key: []}
with torch.no_grad():
for key in (fwd_key, bwd_key):
fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone())
fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone())
fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone())
return fp8_meta_tensors
def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
"""Reset scales and amaxes."""
def reset(key):
if key in self.fp8_meta:
if fp8_meta_tensors is None:
self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale))
self.fp8_meta[key].scale_inv.copy_(
torch.ones_like(self.fp8_meta[key].scale_inv))
self.fp8_meta[key].amax_history.copy_(
torch.zeros_like(self.fp8_meta[key].amax_history))
else:
assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0])
self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1])
self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2])
with torch.no_grad():
reset("scaling_fwd")
reset("scaling_bwd")
def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
state = None
......@@ -380,13 +376,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint()
# Store other pickelable values.
extra = {}
for k, v in self.fp8_meta.items():
if isinstance(v, (bool, int, float, str, list)):
if isinstance(v, (bool, int, float, str, tuple, list)):
extra[k] = v
state["extra_fp8_variables"] = extra
......@@ -414,11 +408,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None:
return
# Restore global FP8 amax buffer.
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
# Restore global FP8 state.
FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
......@@ -527,6 +516,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_group = tp_group
self.tp_group_initialized = True
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
"""returns the FP8 weights."""
fp8_params = []
for param in self.parameters():
if isinstance(param, Float8Tensor) and param.requires_grad:
fp8_params.append(param)
if len(fp8_params) == 0:
return None
return fp8_params
# This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
......@@ -576,7 +575,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
......@@ -594,49 +592,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if is_first_microbatch is not None and not self.primary_weights_in_fp8:
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \
"necessary when using sequence parallelism with FP8."
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.new_fp8_context_id())
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.get_fp8_context_id())
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.fp8_meta, fp8_weights=self._get_fp8_params())
# Activation recomputation is used and this is the first forward phase.
if (
......@@ -653,18 +616,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return
if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
FP8GlobalStateManager.global_amax_reduction,
self.fp8_meta,
self.tp_group,
self.tp_size,
forward=True
)
FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func)
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the
......
......@@ -14,7 +14,6 @@ from .. import cpp_extensions as tex
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
......@@ -65,6 +64,7 @@ class _LayerNormLinear(torch.autograd.Function):
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
......@@ -89,6 +89,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool,
ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -98,7 +99,11 @@ class _LayerNormLinear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
......@@ -196,7 +201,6 @@ class _LayerNormLinear(torch.autograd.Function):
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
......@@ -214,6 +218,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
tex.cast_to_fp8(
......@@ -295,6 +300,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_t_fp8,
ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
)
ctx.activation_dtype = activation_dtype
......@@ -321,6 +327,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
......@@ -344,9 +351,7 @@ class _LayerNormLinear(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
):
with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
(
inputmat,
ln_weight,
......@@ -357,6 +362,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_t_fp8,
ln_out,
fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
......@@ -364,10 +370,13 @@ class _LayerNormLinear(torch.autograd.Function):
weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
if ctx.ub_overlap_rs_dgrad:
ctx.ub_bulk_dgrad = False
......@@ -472,7 +481,7 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm(
weight_t_fp8._data,
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -686,6 +695,8 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -970,7 +981,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta'))
......@@ -990,6 +1000,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
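The dummy tensor is threaded through the autograd Function's forward arguments so it participates in the backward graph; a hook registered on all such dummies then fires once every registered module has produced its gradient, which is a natural point to launch the bucketed backward amax reduction. TE wraps this in add_tensor_for_bwd_reduction_multi_grad_hook; the underlying PyTorch mechanism presumably looks roughly like this sketch (requires a PyTorch version that provides torch.autograd.graph.register_multi_grad_hook):

import torch

dummies = [torch.zeros(1, device="cuda", requires_grad=True) for _ in range(2)]

def on_all_grads_ready(grads):
    # Called once every tensor in `dummies` has received a gradient in this backward
    # pass, i.e. after the last participating module's backward has run.
    pass  # a global backward amax reduction could be launched here

torch.autograd.graph.register_multi_grad_hook(dummies, on_all_grads_ready)

x = torch.randn(4, device="cuda", requires_grad=True)
loss = x.sum() + sum(d.sum() for d in dummies)   # the dummies must take part in the graph
loss.backward()                                  # hook fires after all grads are ready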
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
......@@ -1084,6 +1098,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
......@@ -1132,6 +1150,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
skip_fp8_weight_update,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
......@@ -1156,6 +1175,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad,
self.ub_overlap_ag,
self.ub_name,
self.dummy_tensor,
)
out = fwd_fn(*args)
......
......@@ -13,7 +13,6 @@ from torch.nn import init
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
......@@ -94,6 +93,7 @@ class _LayerNormMLP(torch.autograd.Function):
use_fc2_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
......@@ -121,6 +121,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_overlap_rs: bool,
ub_overlap_ag: bool,
gemm_gelu_fusion: bool,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -131,7 +132,11 @@ class _LayerNormMLP(torch.autograd.Function):
assert_dim_for_fp8_exec(fc1_weight)
assert_dim_for_fp8_exec(fc2_weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
activation_func = _act_func(activation)[0]
......@@ -225,8 +230,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.reset_fp8_meta_scale_inv()
fc1_weight_fp8 = fc1_weight
fc2_weight_fp8 = fc2_weight
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
elif update_fp8_weights:
# Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor(
......@@ -250,6 +253,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
cast_out=fc1_weight_fp8._data,
transpose_out=fc1_weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
tex.fp8_cast_transpose_fused(
fc2_weight,
......@@ -258,6 +262,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
cast_out=fc2_weight_fp8._data,
transpose_out=fc2_weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
tex.cast_to_fp8(
......@@ -510,6 +515,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_t_fp8,
fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
)
ctx.activation_dtype = activation_dtype
ctx.activation = activation
......@@ -538,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_overlap_ag = ub_overlap_ag
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear
if ub_overlap_rs:
......@@ -563,9 +570,7 @@ class _LayerNormMLP(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
):
with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
(
inputmat,
ln_weight,
......@@ -582,6 +587,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_t_fp8,
fc1_bias,
fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
......@@ -592,11 +598,18 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.main_grad = fc2_weight_main_grad
# Primary weights are in FP8.
update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy"
if ctx.fp8 and fc1_weight_t_fp8 is None:
fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache)
if ctx.fp8 and fc2_weight_t_fp8 is None:
fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache)
if ctx.primary_weights_in_fp8:
fc1_weight_t_fp8 = fc1_weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
fc2_weight_t_fp8 = fc2_weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
fc1_weight_t_fp8 = fc1_weight_t_fp8._data
fc2_weight_t_fp8 = fc2_weight_t_fp8._data
activation_func = _act_func(ctx.activation)[1]
......@@ -673,7 +686,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8._data,
fc2_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
......@@ -826,7 +839,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj = None
# FC1 DGRAD: Unconditional
_ = tex.fp8_gemm(
fc1_weight_t_fp8._data,
fc1_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -1151,6 +1164,8 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -1389,7 +1404,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=2)
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta'))
......@@ -1414,6 +1428,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
......@@ -1473,7 +1491,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
@no_torch_dynamo()
def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a feedforward network (MLP Block).
......@@ -1497,6 +1517,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
......@@ -1535,6 +1559,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
skip_fp8_weight_update,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
......@@ -1562,6 +1587,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs,
self.ub_overlap_ag,
self.gemm_gelu_fusion,
self.dummy_tensor,
)
out = fwd_fn(*args)
......
......@@ -11,7 +11,6 @@ import transformer_engine_extensions as tex
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
......@@ -65,6 +64,7 @@ class _Linear(torch.autograd.Function):
bias: torch.Tensor,
use_bias: bool,
is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
......@@ -80,7 +80,8 @@ class _Linear(torch.autograd.Function):
primary_weights_in_fp8: bool,
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_name: str
ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
......@@ -90,7 +91,12 @@ class _Linear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
......@@ -140,7 +146,6 @@ class _Linear(torch.autograd.Function):
# Weight is already in FP8
weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights:
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
......@@ -158,6 +163,7 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward,
cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
)
else:
cast_to_fp8(
......@@ -296,6 +302,7 @@ class _Linear(torch.autograd.Function):
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
......@@ -313,6 +320,7 @@ class _Linear(torch.autograd.Function):
ctx.ub_name = ub_name
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear
if ub_overlap_rs:
......@@ -330,9 +338,7 @@ class _Linear(torch.autograd.Function):
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
):
with torch.cuda.nvtx.range("_Linear_backward"):
(
inputmat,
inputmat_t,
......@@ -340,6 +346,7 @@ class _Linear(torch.autograd.Function):
main_grad,
weight_t_fp8,
fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
......@@ -347,10 +354,14 @@ class _Linear(torch.autograd.Function):
weight.main_grad = main_grad
# Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None:
weight_t_fp8 = weight.transpose(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
if ctx.ub_overlap_ag:
......@@ -361,6 +372,7 @@ class _Linear(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
(
grad_output,
grad_output_c,
......@@ -401,7 +413,7 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
if ctx.fp8:
dgrad, _ = fp8_gemm(
weight_t_fp8._data,
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
......@@ -542,6 +554,8 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -772,7 +786,6 @@ class Linear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8:
self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta'))
......@@ -785,6 +798,10 @@ class Linear(TransformerEngineBaseModule):
else:
self.gemm_bias_unfused_add = False
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
......@@ -858,6 +875,10 @@ class Linear(TransformerEngineBaseModule):
produced)
"""
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8."
......@@ -903,6 +924,7 @@ class Linear(TransformerEngineBaseModule):
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
skip_fp8_weight_update,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
......@@ -919,6 +941,7 @@ class Linear(TransformerEngineBaseModule):
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_name,
self.dummy_tensor,
)
out = linear_fn(*args)
......
......@@ -473,6 +473,15 @@ class TransformerLayer(torch.nn.Module):
if hasattr(child, "set_tensor_parallel_group"):
child.set_tensor_parallel_group(tp_group)
def reset_fp8_meta_tensors(self) -> None:
"""Set TP group"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
continue
if hasattr(child, "reset_fp8_meta_tensors"):
child.reset_fp8_meta_tensors()
def set_context_parallel_group(
self,
cp_group: Union[dist_group_type, None],
......@@ -665,7 +674,8 @@ class TransformerLayer(torch.nn.Module):
# MLP.
mlp_outputs = self.layernorm_mlp(
hidden_states, is_first_microbatch=is_first_microbatch
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs
......