Unverified Commit 73f8d90f authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] cuda graph support (#575)



* FP8 cuda graphs

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>

* Fix numerics

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Exclude torch.compile from numerics tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More numerics fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove fusion from unfused path

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
parent 1b20f2d6
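For context, the goal of this change is to let FP8 modules run under CUDA graph capture and replay. Below is a hedged sketch of the standard PyTorch capture/replay pattern the diff targets; `model` stands in for any capture-safe callable (the FP8-module specifics are assumptions, not the exact API added here).

```python
# Minimal capture/replay sketch, assuming a capture-safe callable `model`.
import torch

model = torch.nn.Linear(1024, 1024).cuda()
static_input = torch.randn(8, 1024, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        _ = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture once. Kernels are recorded against the *current* (capture) stream,
# which is why the C++ changes below switch from getDefaultCUDAStream()
# to getCurrentCUDAStream().
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay with fresh data copied into the static input buffer.
static_input.copy_(torch.randn(8, 1024, device="cuda"))
g.replay()
```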
@@ -157,7 +157,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
   }
   // Catch up the default torch stream
-  at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
   CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main));
   CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
@@ -238,13 +238,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
   int ori_sms = _ub_comm->sms;
   // Catch up the default torch stream
-  at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
   CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
-  CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
-  for (int i = 0; i < _stream_compute.size(); i++) {
-    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
-    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _stop_comm, 0));
-  }
+  CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
+  CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));

   if (A_scale_inverse.numel())
     A_scale_inverse = A_scale_inverse[A_fp8_tensor];
@@ -350,11 +347,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
   int ori_sms = _ub_comm->sms;
   // Catch up the default torch stream
-  at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
   CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
   for (int i = 0; i < _stream_compute.size(); i++) {
     CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
   }
+  CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));

   if (A_scale_inverse.numel())
     A_scale_inverse = A_scale_inverse[A_fp8_tensor];
@@ -469,13 +467,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
       output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
     }
   }
-  _ub_comm->sms = ori_sms;
-  int last_compute_stream_id =
-      (_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
-  CHECK_CUDA(
-      cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
-  CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
-  CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
+  for (int i = 0; i < _stream_compute.size(); i++) {
+    CHECK_CUDA(
+        cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
+    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
+  }
+  _ub_comm->sms = ori_sms;
+  CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
   CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
   at::cuda::setCurrentCUDAStream(stream_main);
@@ -506,7 +504,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
     }
   }
-  at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
   CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
   CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
   CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(),
@@ -805,14 +803,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
   if (B_scale_inverse.numel())
     B_scale_inverse = B_scale_inverse[B_fp8_tensor];

-  at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
   CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
+  CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
+  CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
+  for (int i = 0; i < _stream_compute.size(); i++) {
+    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
+  }

   if (_aggregate2) {
-    // Catch up the default torch stream
-    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
-    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
     const int num_steps = _tp_size / 2;
     char *input_b_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
@@ -877,21 +876,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
         CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
                                    _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
                                    cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
-        CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
-        CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
       }
     }
-    at::cuda::setCurrentCUDAStream(stream_main);
-    int last_compute_stream_id =
-        (num_steps + _stream_compute.size() - 1) % _stream_compute.size();
-    CHECK_CUDA(
-        cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
   } else {
-    // Catch up the default torch stream
-    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
-    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
-    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
     for (int i = 0; i < _tp_size; i++) {
       // Set the userbuffer id. Buffer under send is the input for the current
       // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
@@ -936,16 +923,19 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
        CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
                                   _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
                                   cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
-        CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
-        CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
      }
    }
-    at::cuda::setCurrentCUDAStream(stream_main);
-    int last_compute_stream_id = (_tp_size + _stream_compute.size() - 1) % _stream_compute.size();
-    CHECK_CUDA(
-        cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
   }
+  for (int i = 0; i < _stream_compute.size(); i++) {
+    CHECK_CUDA(
+        cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
     CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
+  }
+  CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
+  CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
+  CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
+  CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
+  at::cuda::setCurrentCUDAStream(stream_main);

   return D;
 }  // split_overlap_ag
......
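The hunks above replace every use of the default torch stream with the current stream and then synchronize the helper communication/compute streams against it with events. A hedged PyTorch-side analogue of that ordering pattern (illustrative only; the real code does this in C++ with cudaEventRecord/cudaStreamWaitEvent) is:

```python
# Record an event on the current stream and make a side stream wait on it,
# instead of assuming work was queued on the default stream. During graph
# capture the "current" stream is the capture stream, not the default one.
import torch

side_stream = torch.cuda.Stream()
x = torch.randn(4096, 4096, device="cuda")
y = x @ x  # queued on the current stream

ev = torch.cuda.Event()
ev.record(torch.cuda.current_stream())   # "catch up" point
side_stream.wait_event(ev)               # side stream waits for y to be ready
with torch.cuda.stream(side_stream):
    z = y.sum()

# Re-join before using z on the main stream.
torch.cuda.current_stream().wait_stream(side_stream)
print(z.item())
```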
@@ -43,6 +43,7 @@
 #include <transformer_engine/softmax.h>
 #include <transformer_engine/transformer_engine.h>
 #include <transformer_engine/transpose.h>
+#include <transformer_engine/cast_transpose_noop.h>

 namespace transformer_engine {
......
@@ -223,6 +223,17 @@ void fused_cast_transpose(at::Tensor input,
                          );

+void fused_cast_transpose_noop(at::Tensor input,
+                               at::Tensor noop,
+                               at::Tensor scale,
+                               at::Tensor amax,
+                               at::Tensor scale_inv,
+                               at::Tensor input_cast,
+                               at::Tensor input_transpose,
+                               transformer_engine::DType otype
+                               );
+
 std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
                                                    at::Tensor scale,
                                                    at::Tensor amax,
@@ -263,6 +274,17 @@ at::Tensor fp8_transpose(at::Tensor input,
                         transformer_engine::DType otype
                         );

+void fp8_transpose_noalloc(at::Tensor input,
+                           at::Tensor output,
+                           transformer_engine::DType otype
+                           );
+
+void fp8_transpose_noalloc_noop(at::Tensor input,
+                                at::Tensor output,
+                                at::Tensor noop,
+                                transformer_engine::DType otype
+                                );
+
 /***************************************************************************************************
  * Activations
  **************************************************************************************************/
@@ -559,14 +581,11 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
  * FP8 recipe
  **************************************************************************************************/

-void fused_amax_and_scale_update(const at::Tensor &amax_history,
-                                 const at::Tensor &scale,
-                                 const at::Tensor &scale_inv,
-                                 const at::Tensor &scale_inv_mask,
-                                 at::Tensor updated_amax_history,
-                                 at::Tensor updated_scale,
-                                 at::Tensor updated_scale_inv,
-                                 const std::string& amax_compute_algo,
+void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
+                                                 std::vector<at::Tensor> amax_histories,
+                                                 std::vector<at::Tensor> scales,
+                                                 std::vector<at::Tensor> scale_invs,
+                                                 const std::string &amax_compute_algo,
                                  transformer_engine::DType fp8_dtype,
                                  float margin);
......
@@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD");
   m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD");
   m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
+  m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop,
+        "Fused Cast + Transpose with noop option");
   m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
         "Fused Cast + Transpose + BGRAD");
   m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad,
@@ -67,6 +69,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fused_attn_bwd", &fused_attn_bwd,
         "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
   m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
+  m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O");
+  m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop,
+        "Transpose with FP8 I/O with noop option.");
   m.def("gelu", &gelu, "GeLU with FP8 output");
   m.def("relu", &relu, "ReLU with FP8 output");
   m.def("geglu", &geglu, "GeGLU with FP8 output");
@@ -82,9 +87,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
   m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
-  m.def("fused_amax_and_scale_update",
-        &fused_amax_and_scale_update,
-        "Update amax history and FP8 scale");
+  m.def("fused_amax_and_scale_update_after_reduction",
+        &fused_amax_and_scale_update_after_reduction,
+        "Update amax history and FP8 scale/scale_inv after reduction");

   // fused apply rope
   m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD");
......
@@ -11,24 +11,50 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>

-void fused_amax_and_scale_update(const at::Tensor &amax_history,
-                                 const at::Tensor &scale,
-                                 const at::Tensor &scale_inv,
-                                 const at::Tensor &scale_inv_mask,
-                                 at::Tensor updated_amax_history,
-                                 at::Tensor updated_scale,
-                                 at::Tensor updated_scale_inv,
-                                 const std::string& amax_compute_algo,
-                                 transformer_engine::DType fp8_dtype,
-                                 float margin) {
-  nvte_delayed_scaling_recipe_amax_and_scale_update(
-      makeTransformerEngineTensor(amax_history).data(),
-      makeTransformerEngineTensor(scale).data(),
-      makeTransformerEngineTensor(scale_inv).data(),
-      makeTransformerEngineTensor(scale_inv_mask).data(),
-      makeTransformerEngineTensor(updated_amax_history).data(),
-      makeTransformerEngineTensor(updated_scale).data(),
-      makeTransformerEngineTensor(updated_scale_inv).data(),
+void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
+                                                 std::vector<at::Tensor> amax_histories,
+                                                 std::vector<at::Tensor> scales,
+                                                 std::vector<at::Tensor> scale_invs,
+                                                 const std::string &amax_compute_algo,
+                                                 transformer_engine::DType fp8_dtype,
+                                                 float margin) {
+  using namespace transformer_engine;
+  size_t num_tensors = amax_histories.size();
+  std::vector<Tensor> t_amax_histories(num_tensors);
+  std::vector<Tensor> t_scales(num_tensors);
+  std::vector<Tensor> t_scale_invs(num_tensors);
+  std::vector<NVTETensor> te_amax_histories(num_tensors);
+  std::vector<NVTETensor> te_scales(num_tensors);
+  std::vector<NVTETensor> te_scale_invs(num_tensors);
+  for (size_t i = 0; i < num_tensors; i++) {
+    t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
+    auto amax_sizes = amax_histories[i].sizes().vec();
+    std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()};
+    t_amax_histories[i].data.shape = amax_shape;
+    t_amax_histories[i].data.dtype = DType::kFloat32;
+    t_scales[i].data.dptr = scales[i].data_ptr();
+    auto scale_sizes = scales[i].sizes().vec();
+    std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()};
+    t_scales[i].data.shape = scale_shape;
+    t_scales[i].data.dtype = DType::kFloat32;
+    t_scale_invs[i].data.dptr = scale_invs[i].data_ptr();
+    auto scale_inv_sizes = scale_invs[i].sizes().vec();
+    std::vector<size_t> scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()};
+    t_scale_invs[i].data.shape = scale_inv_shape;
+    t_scale_invs[i].data.dtype = DType::kFloat32;
+    te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
+    te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
+    te_scale_invs[i] = reinterpret_cast<NVTETensor>(&t_scale_invs[i]);
+  }
+  nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
+      makeTransformerEngineTensor(amax_reduction_buffer).data(),
+      te_amax_histories,
+      te_scales,
+      te_scale_invs,
       amax_compute_algo.c_str(),
       static_cast<NVTEDType>(fp8_dtype),
       margin,
......
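A hedged sketch of how this new binding might be driven from Python is below. The extension-module import name and the tensor shapes are assumptions for illustration; only the argument order mirrors the declaration in the diff above.

```python
# Illustrative call into the new post-reduction amax/scale update binding.
import torch
import transformer_engine_extensions as tex  # assumed extension module name

num_buckets, history_len, num_scales = 2, 1024, 4
amax_histories = [torch.zeros(history_len, num_scales, device="cuda") for _ in range(num_buckets)]
scales = [torch.ones(num_scales, device="cuda") for _ in range(num_buckets)]
scale_invs = [torch.ones(num_scales, device="cuda") for _ in range(num_buckets)]

# Flat buffer produced by a previously issued amax all-reduce; layout is an
# assumption here (one slot per amax entry across all tensors in the bucket).
amax_reduction_buffer = torch.zeros(num_buckets * num_scales, device="cuda")

tex.fused_amax_and_scale_update_after_reduction(
    amax_reduction_buffer,
    amax_histories,
    scales,
    scale_invs,
    "max",                   # amax_compute_algo
    tex.DType.kFloat8E4M3,   # fp8_dtype
    0.0,                     # margin
)
```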
@@ -32,6 +32,35 @@ void fused_cast_transpose(at::Tensor input,
 }

+void fused_cast_transpose_noop(at::Tensor input,
+                               at::Tensor noop,
+                               at::Tensor scale,
+                               at::Tensor amax,
+                               at::Tensor scale_inv,
+                               at::Tensor input_cast,
+                               at::Tensor input_transpose,
+                               transformer_engine::DType otype
+) {
+  using namespace transformer_engine;
+
+  size_t M = static_cast<size_t>(input.size(0));
+  size_t N = static_cast<size_t>(input.size(1));
+
+  auto input_cu = makeTransformerEngineTensor(input);
+  auto noop_cu = makeTransformerEngineTensor(noop);
+  auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
+                                                    amax.data_ptr(), scale.data_ptr(),
+                                                    scale_inv.data_ptr());
+  auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
+                                                         amax.data_ptr(), scale.data_ptr(),
+                                                         scale_inv.data_ptr());
+
+  nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(),
+                                output_transpose_cu.data(),
+                                at::cuda::getCurrentCUDAStream());
+}
+
 std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
                                                    at::Tensor scale,
                                                    at::Tensor amax,
@@ -319,3 +348,39 @@ at::Tensor fp8_transpose(at::Tensor input,

   return output;
 }
+
+void fp8_transpose_noalloc(at::Tensor input,
+                           at::Tensor output,
+                           transformer_engine::DType otype
+) {
+  using namespace transformer_engine;
+
+  size_t M = static_cast<size_t>(input.size(0));
+  size_t N = static_cast<size_t>(input.size(1));
+
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
+
+  nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
+}
+
+void fp8_transpose_noalloc_noop(at::Tensor input,
+                                at::Tensor output,
+                                at::Tensor noop,
+                                transformer_engine::DType otype
+) {
+  using namespace transformer_engine;
+
+  size_t M = static_cast<size_t>(input.size(0));
+  size_t N = static_cast<size_t>(input.size(1));
+
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
+  auto noop_cu = makeTransformerEngineTensor(noop);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
+
+  nvte_transpose_with_noop(
+      input_cu.data(), noop_cu.data(), output_cu.data(),
+      at::cuda::getCurrentCUDAStream());
+}
@@ -5,10 +5,10 @@
 """Methods needed for distributed training (DP/TP)."""
 import warnings
 from contextlib import contextmanager, AbstractContextManager, ContextDecorator
-from typing import Any, Dict, Union, Optional, Callable, Tuple
+from typing import Any, Dict, List, Union, Optional, Callable, Tuple

 import torch
-from torch.cuda import _lazy_call
+from torch.cuda import _lazy_call, _lazy_init
 from torch.utils.checkpoint import detach_variable, noop_context_fn

 from .utils import safely_set_viewless_tensor_data
@@ -31,15 +31,60 @@ _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
 _FP8_ACTIVATION_RECOMPUTE_PHASE = False

-def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None:
-    """Sets the random number generator state of the current GPU.
-
-    Arguments:
-        new_state (torch.ByteTensor): The desired state
-
-    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
-    with a single change: the input state is not cloned. Cloning caused
-    major performance issues for +4 GPU cases.
-    """
+_ALL_ACTIVE_RNG_STATES = {}
+
+
+def get_all_rng_states() -> bool:
+    """Returns all generator states used by `CudaRNGStatesTracker`."""
+    return _ALL_ACTIVE_RNG_STATES
+
+
+def set_all_rng_states(states: List) -> None:
+    """Updates all generator states used by `CudaRNGStatesTracker`."""
+    global _ALL_ACTIVE_RNG_STATES
+    _ALL_ACTIVE_RNG_STATES = states
+
+
+def graph_safe_rng_available() -> bool:
+    """Returns whether cuda graph safe RNG state manipulation is supported."""
+    return (hasattr(torch.cuda.CUDAGraph, "register_generator_state")
+            and hasattr(torch.Generator, "graphsafe_set_state")
+            and hasattr(torch.Generator, "graphsafe_get_state")
+            and hasattr(torch.Generator, "clone_state"))
+
+
+def _get_cuda_rng_state(
+    device: Union[int, str, torch.device] = "cuda",
+    clone: bool = False,
+    graph_safe: bool = True,
+) -> torch.Tensor:
+    """Return the random number generator state of the specified GPU."""
+
+    _lazy_init()
+    if isinstance(device, str):
+        device = torch.device(device)
+    elif isinstance(device, int):
+        device = torch.device("cuda", device)
+    idx = device.index
+    if idx is None:
+        idx = torch.cuda.current_device()
+    default_generator = torch.cuda.default_generators[idx]
+    if graph_safe_rng_available() and graph_safe:
+        if clone:
+            # Reference to the cloned generator state
+            return default_generator.clone_state()
+        # Reference to the current generator state
+        return default_generator.graphsafe_get_state()
+    return default_generator.get_state()
+
+
+def _set_cuda_rng_state(
+    new_state: torch.Tensor,
+    device: Union[int, str] = -1,
+    graph_safe = True,
+) -> None:
+    """Sets the random number generator state of the current GPU."""
     if device == -1:
         device = torch.device("cuda")
     elif isinstance(device, str):
@@ -52,6 +97,9 @@ def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -
         if idx is None:
             idx = torch.cuda.current_device()
         default_generator = torch.cuda.default_generators[idx]
+        if graph_safe_rng_available() and graph_safe:
+            default_generator.graphsafe_set_state(new_state)
+            return
         default_generator.set_state(new_state)

     _lazy_call(cb)
@@ -206,7 +254,7 @@ class _CheckpointFunction(torch.autograd.Function):

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
         if get_rng_state_tracker is not None:
             ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
@@ -271,13 +319,13 @@ class _CheckpointFunction(torch.autograd.Function):

         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = torch.cuda.get_rng_state()
+        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
         if get_rng_state_tracker is not None:
             bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

         # Set the states to what it used to be before the forward pass.
         torch.set_rng_state(ctx.fwd_cpu_rng_state)
-        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
+        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
@@ -291,7 +339,7 @@ class _CheckpointFunction(torch.autograd.Function):

         # Set the states back to what it was at the start of this function.
         torch.set_rng_state(bwd_cpu_rng_state)
-        _set_cuda_rng_state(bwd_cuda_rng_state)
+        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
@@ -317,6 +365,7 @@ class _CheckpointFunction(torch.autograd.Function):
             )
         return (None, None, None, None, None, None) + grads

+
 class _CheckpointFrame:
     """
     Storage frame for forward RNG states and detached activations from the forward recompute.
@@ -338,7 +387,7 @@ class _CheckpointFrame:
         """Cache fwd/bwd RNG states in the frame to restore later."""
         rng_states = (
             torch.get_rng_state(),
-            torch.cuda.get_rng_state(),
+            _get_cuda_rng_state(graph_safe=False),
         )
         if self.get_rng_state_tracker is not None:
             rng_states += (self.get_rng_state_tracker().get_states(), )
@@ -356,7 +405,7 @@ class _CheckpointFrame:
             rng_states = self.bwd_rng_states

         torch.set_rng_state(rng_states[0])
-        _set_cuda_rng_state(rng_states[1])
+        _set_cuda_rng_state(rng_states[1], graph_safe=False)
         if self.get_rng_state_tracker is not None:
             self.get_rng_state_tracker().set_states(rng_states[2])
@@ -604,6 +653,7 @@ def checkpoint(
     return out

+
 class CudaRNGStatesTracker:
     """
     For model parallelism, multiple RNG states need to simultaneously exist in order
@@ -664,13 +714,23 @@ class CudaRNGStatesTracker:
         # Check that state is not already defined.
         if name in self.states_:
             raise Exception(f"cuda rng state {name} already exists")
-        # Get the current rng state.
-        orig_rng_state = torch.cuda.get_rng_state()
-        # Set the new state and store it.
-        torch.cuda.manual_seed(seed)
-        self.states_[name] = torch.cuda.get_rng_state()
-        # Reset rng state to what it was.
-        _set_cuda_rng_state(orig_rng_state)
+
+        if graph_safe_rng_available():
+            new_state = _get_cuda_rng_state(clone=True)
+            new_state.manual_seed(seed)
+            self.states_[name] = new_state
+            # Update global states.
+            set_all_rng_states(self.states_)
+        else:
+            # Get the current rng state.
+            orig_rng_state = _get_cuda_rng_state()
+            # Set the new state and store it.
+            torch.cuda.manual_seed(seed)
+            self.states_[name] = _get_cuda_rng_state(clone=True)
+            # Reset rng state to what it was.
+            _set_cuda_rng_state(orig_rng_state)
+            # Update global states.
+            set_all_rng_states(self.states_)

     @contextmanager
     def fork(self, name: str = "model-parallel-rng"):
@@ -684,16 +744,17 @@ class CudaRNGStatesTracker:
         # Check if we have added the state
         if name not in self.states_:
             raise Exception(f"cuda rng state {name} is not added")
-        # Store current rng state.
-        orig_cuda_rng_state = torch.cuda.get_rng_state()
+        # Get the reference to current rng state.
+        orig_cuda_rng_state = _get_cuda_rng_state()
         # Set rng state to the desired one
         _set_cuda_rng_state(self.states_[name])
         # Do the stuff we wanted to do.
         try:
             yield
         finally:
-            # Update the current rng state for later use.
-            self.states_[name] = torch.cuda.get_rng_state()
+            # this is redundant with graph-safe API
+            if not graph_safe_rng_available():
+                self.states_[name] = _get_cuda_rng_state()
             # And set the state to the original state we started with.
             _set_cuda_rng_state(orig_cuda_rng_state)
......
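A hedged sketch of the graph-safe RNG flow these functions enable is below. The import path follows the file edited above (transformer_engine/pytorch/distributed.py); `register_generator_state` is the PyTorch hook that `graph_safe_rng_available()` probes for, so this only applies on builds where that API exists.

```python
# Illustrative use of the tracker with graph-safe generator states.
import torch
from transformer_engine.pytorch.distributed import (
    CudaRNGStatesTracker, get_all_rng_states, graph_safe_rng_available)

tracker = CudaRNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)

if graph_safe_rng_available():
    graph = torch.cuda.CUDAGraph()
    # Register every tracked generator state so graph replays advance RNG correctly.
    for state in get_all_rng_states().values():
        graph.register_generator_state(state)

# Any RNG op inside fork() draws from the tracked state.
with tracker.fork("model-parallel-rng"):
    mask = torch.rand(4, 4, device="cuda") < 0.9
```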
@@ -16,6 +16,7 @@ from .fp8 import FP8GlobalStateManager
 aten = torch.ops.aten
 c10d = torch.ops.c10d

+updated_fp8_params = {}

 def _make_fp8_attr_property_funcs(name: str) -> Any:
@@ -67,6 +68,31 @@ class _FromFloat8Func(torch.autograd.Function):
         return grad, None

+
+def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
+    """Amax scale and update when there is at least 1 trainable FP8 parameter."""
+    param_id = id(param._data)
+
+    if param_id not in FP8GlobalStateManager.fp8_param_to_autocast:
+        return
+
+    autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id]
+
+    if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params:
+        return
+
+    if autocast_key in updated_fp8_params:
+        updated_fp8_params[autocast_key].add(param_id)
+    else:
+        updated_fp8_params[autocast_key] = {param_id}
+
+    current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key]
+    # All FP8 trainable parameters have been updated.
+    if updated_fp8_params[autocast_key] == current_fp8_params_set:
+        FP8GlobalStateManager.reduce_and_update_fp8_tensors(
+            forward=True, fp8_weights=True)
+        del updated_fp8_params[autocast_key]
+
+
 class _ToFloat8Func(torch.autograd.Function):
     """Cast to FP8 from other dtype"""
     @staticmethod
@@ -167,6 +193,7 @@ class _ToFloat8Func(torch.autograd.Function):
         # Assume that we want gradients in full precision
         return grad, None, None, None, None, None, None, None

+
 class _IdentityFunc(torch.autograd.Function):
     """Identity function
@@ -307,8 +334,9 @@ class Float8Tensor(torch.Tensor):
         ), f"Unsupported fp8_dtype {fp8_dtype}."
         self._fp8_dtype: tex.DType = fp8_dtype

-        # Cached transpose
+        # Transposed version of `_data`.
         self._transpose: Optional[Float8Tensor] = None
+        self._transpose_invalid: bool = True

         # FP8 scale-inverse
         self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
@@ -435,80 +463,51 @@ class Float8Tensor(torch.Tensor):
             return _IdentityFunc.apply(self)
         return super().expand_as(other)

-    def transpose(
-        self,
-        dim0: int = 0,
-        dim1: int = 1,
-        *,
-        update_cache: str | bool = "reuse_only",
-    ) -> torch.Tensor:
-        """
-        Swap tensor dimensions
-
-        For basic 2D matrix transposes, an optimized transpose kernel
-        is applied and a Float8Tensor is returned.
-
-        Parameters
-        ----------
-        dim0: int, default = 0
-              The first dimension to be transposed
-        dim1: int, default = 1
-              The second dimension to be transposed
-        update_cache: str or bool, default = "reuse_only"
-                      Memoization behavior. Options are
-                      "reuse_only"/`False` (reuse cached value if
-                      available, otherwise calculate transpose without
-                      caching), "force"/`True` (calculate transpose
-                      and cache), "lazy" (reuse cached value if
-                      available, otherwise calculate transpose and
-                      cache if possible). Caching is only supported
-                      for basic 2D transposes and the cache is reset
-                      after any in-place operations.
-        """
-
-        # Check caching mode
-        if not isinstance(update_cache, str):
-            update_cache = "force" if update_cache else "reuse_only"
-        if update_cache not in ("force", "reuse_only", "lazy"):
-            raise ValueError(
-                "Supported values for update_cache are "
-                '"force" (True), "reuse_only" (False), "lazy" '
-                f"(got {update_cache})"
-            )
-
-        # Handle non-2D transposes
-        if -self.dim() <= dim0 < 0:
-            dim0 += self.dim()
-        if -self.dim() <= dim1 < 0:
-            dim1 += self.dim()
-        if self.dim() != 2 or dim0 == dim1:
-            if update_cache == "force":
-                raise ValueError(
-                    "Transpose caching is only supported for basic 2D transposes "
-                    f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
-                )
-            return super().transpose(dim0, dim1)
-
-        # Clear cache if needed
-        if update_cache == "force":
-            self._transpose = None
-
-        # Compute transpose if needed
-        out = self._transpose
-        if out is None:
-            out = Float8Tensor.make_like(
-                self,
-                data=tex.fp8_transpose(
-                    self._data.contiguous(),
-                    self._fp8_dtype,
-                ),
-            )
-
-        # Update cache if needed
-        if update_cache in ("force", "lazy"):
-            self._transpose = out
-        return out
+    def transpose_2d(
+        self,
+        *,
+        cache: bool = False,
+        noop_flag: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        2D transpose with caching support.
+
+        Parameters
+        ----------
+        cache: bool, default = `False`
+               Whether or not to cache the transpose.
+        noop_flag: Optional[torch.Tensor], default = `None`
+                   Only used if argument `cache` is `True`, ignored otherwise.
+                   A single element fp32 tensor with a value of 1.0 or 0.0
+                   which is treated as a boolean. `1.0` forces recompute
+                   and `0.0` executes a noop using the same kernel.
+        """
+        assert self.dim() == 2, f"{self.dim()}-D transpose not supported."
+
+        # Case: no caching.
+        if not cache:
+            return tex.fp8_transpose(self._data, self._fp8_dtype)
+
+        # Case: reuse cache without calling a kernel.
+        if not self._transpose_invalid and noop_flag is None:
+            assert self._transpose is not None, "Tranpose cache is empty."
+            return self._transpose
+
+        # Allocate transpose if needed.
+        data_2d = self._data.reshape(-1, self._data.shape[-1])
+        if self._transpose is None:
+            shape = (data_2d.shape[1], data_2d.shape[0])
+            self._transpose = torch.empty(shape, dtype=torch.uint8, device=self._data.device)
+
+        # Case: recompute transpose and store cache.
+        if noop_flag is None:
+            tex.fp8_transpose_noalloc(data_2d, self._transpose, self._fp8_dtype)
+        else:
+            # Case: cuda graph capture.
+            tex.fp8_transpose_noalloc_noop(data_2d, self._transpose, noop_flag, self._fp8_dtype)
+
+        self._transpose_invalid = False
+        return self._transpose

     @torch.no_grad()
     def reset_fp8_meta_scale_inv(self) -> None:
@@ -519,13 +518,11 @@ class Float8Tensor(torch.Tensor):
        the tensor.

        """
-        if self._fp8_meta is None:
-            return
+        assert self._fp8_meta is not None, "FP8 meta tensors not found."
         fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
             forward=self._fp8_meta_forward,
         )
-        scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index]
-        scale_inv.view(1).copy_(self._scale_inv.view(1))
+        self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0])

     def to_dtype(self, dtype: torch.dtype) -> Float8Tensor:
         """Create `Float8Tensor` with given nominal dtype
@@ -541,12 +538,11 @@ class Float8Tensor(torch.Tensor):
         )

     def _reset_caches(self) -> None:
-        """Reset cached values
-
+        """
+        Set transpose cache as invalid.
         Should be called after any in-place operation.
         """
-        self._transpose = None
+        self._transpose_invalid = True

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -574,7 +570,7 @@ class Float8Tensor(torch.Tensor):
             # Directly copy FP8 data if possible
             if dst._fp8_dtype == src._fp8_dtype:
                 dst._data.copy_(src._data)
-                dst._scale_inv = src._scale_inv.clone()
+                dst._scale_inv.copy_(src._scale_inv.detach())
                 if dst._fp8_meta is not None:
                     if src._fp8_meta is None:
                         src_min, src_max = src.from_float8().aminmax()
@@ -600,7 +596,6 @@ class Float8Tensor(torch.Tensor):
                 dst.copy_(src.from_float8())
             elif dst_is_fp8 and not src_is_fp8:
                 # Make sure input is in expected format
-
                 src = src.expand(dst.size())
                 src = src.to(
@@ -619,7 +614,7 @@ class Float8Tensor(torch.Tensor):
                 fp8_meta_index = dst._fp8_meta_index
                 scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
                 amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
-                dst._scale_inv = scale.detach().view(1).reciprocal()
+                dst._scale_inv.copy_(scale.detach().reciprocal())

                 # Cast to FP8
                 if not dst._data.is_contiguous():
@@ -633,6 +628,9 @@ class Float8Tensor(torch.Tensor):
                     dst._fp8_dtype,
                 )

+                # This branch is where the FP8 parameters are updated in-place during optimization.
+                # Handle forward amax reduction.
+                post_optimizer_step_fwd_amax_reduction(dst)
             else:

                 # Invalid case
@@ -641,6 +639,7 @@ class Float8Tensor(torch.Tensor):
             # Nothing to return for in-place ops
             if dst_is_fp8:
                 dst._reset_caches()
+
             return None

         # Slice op
@@ -764,6 +763,7 @@ class Float8Tensor(torch.Tensor):
     _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index"))
     _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype"))
     _transpose = property(**_make_fp8_attr_property_funcs("transpose"))
+    _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
     _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))

     # Do not force the Float8Tensor type on the returned tensor
......
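A hedged sketch of the new `Float8Tensor.transpose_2d` cache semantics shown above. Float8Tensor construction details vary by version, so treat the creation of `weight_fp8` as an assumption; the call pattern mirrors the method's docstring.

```python
# Illustrative transpose_2d usage with and without the cuda-graph-friendly cache.
import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

weight = torch.randn(256, 128, device="cuda")
weight_fp8 = Float8Tensor.to_float8(weight)  # assumed helper, for illustration only

# One-off transpose: no caching, returns the raw FP8 transpose data.
wt = weight_fp8.transpose_2d()

# Cached transpose with a noop flag: the one-element flag tensor can be filled
# at graph replay time to decide whether the cached transpose gets refreshed.
noop_flag = torch.zeros(1, dtype=torch.float32, device="cuda")
wt_cached = weight_fp8.transpose_2d(cache=True, noop_flag=noop_flag)
```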
@@ -51,6 +51,17 @@ def get_fp8_te_dtype(
     return tex.DType.kFloat8E5M2


+def get_fp8_max(
+    fp8_recipe: DelayedScaling, fprop_tensor: bool = True
+) -> tex.DType:
+    """Get max representible FP8 value."""
+    if fp8_recipe.fp8_format == Format.E4M3 or (
+        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
+    ):
+        return Format.E4M3.value.max_fwd
+    return Format.E5M2.value.max_fwd
+
+
 class FP8GlobalStateManager:
     """Class to keep track of and manipulate the global
     FP8 state at different stages of execution.
@@ -61,20 +72,21 @@ class FP8GlobalStateManager:
     FP8_DISTRIBUTED_GROUP = None
     FP8_PARAMETERS = False
     IS_FIRST_FP8_MODULE = False
-    FP8_AUTOCAST_COUNTER = 0
-    FP8_CURRENT_CONTEXT_ID = 0
+    FP8_GRAPH_CAPTURING = False
     FP8_AUTOCAST_DEPTH = 0
-    global_fp8_buffer = {}
+    global_amax_buffer = {}
+    global_amax_history_buffer = {}
+    global_scale_buffer = {}
+    global_scale_inv_buffer = {}
     fp8_tensors_recompute_buffer = []
-    amax_forward_global_reduce_func = None
-    buffer_delete_key_fwd = None
-    buffer_delete_key_bwd = None
-    amax_reduce_handle_fwd = None
     fp8_available = None
     reason_for_no_fp8 = ""
-    dp_amax_reduce_interval = None
-    dp_amax_reduce_forward_idx = 0
-    dp_amax_reduce_backward_idx = 0
+    multi_grad_hook_tensors = []
+    bwd_amax_update_hook_registered = False
+    autocast_arguments = {}
+    autocast_to_fp8_params = {}
+    fp8_param_to_autocast = {}
+    skip_fp8_weight_update_tensor = None

     @classmethod
     def reset(cls) -> None:
@@ -83,21 +95,35 @@ class FP8GlobalStateManager:
         cls.FP8_CALIBRATION = False
         cls.FP8_RECIPE = None
         cls.FP8_DISTRIBUTED_GROUP = None
+        cls.FP8_PARAMETERS = False
         cls.IS_FIRST_FP8_MODULE = False
-        cls.FP8_AUTOCAST_COUNTER = 0
-        cls.FP8_CURRENT_CONTEXT_ID = 0
+        cls.FP8_GRAPH_CAPTURING = False
         cls.FP8_AUTOCAST_DEPTH = 0
-        cls.global_fp8_buffer = {}
+        cls.global_amax_buffer = {}
+        cls.global_amax_history_buffer = {}
+        cls.global_scale_buffer = {}
+        cls.global_scale_inv_buffer = {}
         cls.fp8_tensors_recompute_buffer = []
-        cls.amax_forward_global_reduce_func = None
-        cls.buffer_delete_key_fwd = None
-        cls.buffer_delete_key_bwd = None
-        cls.amax_reduce_handle_fwd = None
         cls.fp8_available = None
         cls.reason_for_no_fp8 = ""
-        cls.dp_amax_reduce_interval = None
-        cls.dp_amax_reduce_forward_idx = 0
-        cls.dp_amax_reduce_backward_idx = 0
+        cls.multi_grad_hook_tensors = []
+        cls.bwd_amax_update_hook_registered = False
+        cls.autocast_arguments = {}
+        cls.autocast_to_fp8_params = {}
+        cls.fp8_param_to_autocast = {}
+        cls.skip_fp8_weight_update_tensor = None
+
+    @classmethod
+    def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
+        """`skip_fp8_weight_update_tensor` inplace setter."""
+        if cls.skip_fp8_weight_update_tensor is None:
+            cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
+        cls.skip_fp8_weight_update_tensor.fill_(skip)
+
+    @classmethod
+    def get_skip_fp8_weight_update_tensor(cls) -> None:
+        """`skip_fp8_weight_update_tensor` getter."""
+        return cls.skip_fp8_weight_update_tensor

     @classmethod
     def is_fp8_available(cls) -> Tuple[bool, str]:
@@ -106,44 +132,6 @@ class FP8GlobalStateManager:
             cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support()
         return cls.fp8_available, cls.reason_for_no_fp8

-    @classmethod
-    def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]:
-        """Returns global fp8 state variables."""
-        # Convert attributes to dictionary to make future proof against
-        # changes in global state variables in order to make setting the
-        # checkpoint backwards compatible.
-        global_fp8_state = {}
-        global_fp8_state["FP8_AUTOCAST_COUNTER"] = cls.FP8_AUTOCAST_COUNTER
-        global_fp8_state["FP8_CURRENT_CONTEXT_ID"] = cls.FP8_CURRENT_CONTEXT_ID
-        global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH
-        global_fp8_state["buffer_delete_key_fwd"] = cls.buffer_delete_key_fwd
-        global_fp8_state["buffer_delete_key_bwd"] = cls.buffer_delete_key_bwd
-        global_fp8_state["dp_amax_reduce_interval"] = cls.dp_amax_reduce_interval
-        global_fp8_state["dp_amax_reduce_forward_idx"] = cls.dp_amax_reduce_forward_idx
-        global_fp8_state["dp_amax_reduce_backward_idx"] = cls.dp_amax_reduce_backward_idx
-        return global_fp8_state
-
-    @classmethod
-    def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> None:
-        """Sets global fp8 state variables."""
-        for k, v in state.items():
-            if hasattr(cls, k):
-                setattr(cls, k, v)
-
-    @classmethod
-    def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]:
-        """Returns global fp8 amax buffer."""
-        return cls.global_fp8_buffer
-
-    @classmethod
-    def set_global_fp8_buffer_checkpoint(cls, buffer: Dict[str, List[torch.Tensor]]) -> None:
-        """Sets global fp8 amax buffer."""
-        # Map all tensors back to GPU.
-        for k, v in buffer.items():
-            buffer[k] = [tensor.cuda() for tensor in v]
-
-        cls.global_fp8_buffer = buffer
-
     @staticmethod
     def get_meta_tensor_key(forward: bool = True) -> str:
         """Returns scaling key in `fp8_meta`."""
...@@ -152,121 +140,102 @@ class FP8GlobalStateManager: ...@@ -152,121 +140,102 @@ class FP8GlobalStateManager:
return "scaling_bwd" return "scaling_bwd"
@staticmethod @staticmethod
def get_buffer_position_key(forward: bool = True) -> str: def get_fwd_bwd_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`.""" """Convert bool `forward` to string."""
if forward: return "forward" if forward else "backward"
return "global_fp8_buffer_pos_fwd"
return "global_fp8_buffer_pos_bwd"
@staticmethod
def get_autocast_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "autocast_id_fwd"
return "autocast_id_bwd"
@staticmethod
def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
"""Return a key in `_global_fp8_buffer` for the AMAX storage."""
if forward:
return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
@classmethod @classmethod
def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]: def get_buffer_info(cls) -> str:
"""Return AMAX reduction wait handle of forward prop.""" """
return cls.amax_reduce_handle_fwd Returns a key for `fp8_meta` that stores the module's index
in the global buffers along with autocast information.
"""
return "buffer_index_and_autocast_key"
@classmethod @classmethod
def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None: def get_key_in_buffer(
"""Sets up the function to call during autocast exit.""" cls,
cls.amax_forward_global_reduce_func = f forward: bool,
fp8_weights: bool,
fp8_recipe: DelayedScaling,
fp8_group: dist_group_type,
) -> str:
"""Returns a key into the global FP8 buffers."""
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
fwd_bwd_key = cls.get_fwd_bwd_key(forward)
return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}"
@classmethod @classmethod
def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]:
"""Append 1D tensor `amax` to global buffer.""" """Splits buffer key into relevant parts."""
buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) forward, fp8_weights, autocast_key = key.split("_", 2)
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) forward = forward == "forward"
buffer_position_key = cls.get_buffer_position_key(forward=forward) fp8_weights = fp8_weights == "True"
return forward, fp8_weights, autocast_key
if buffer_key not in cls.global_fp8_buffer:
cls.global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
else:
cls.global_fp8_buffer[buffer_key].append(
fp8_meta[fp8_meta_tensor_key].amax_history[0]
)
if buffer_position_key not in fp8_meta:
fp8_meta[buffer_position_key] = len(cls.global_fp8_buffer[buffer_key]) - 1
# Catch incorrect fp8_autocast usage.
assert fp8_meta[buffer_position_key] == len(cls.global_fp8_buffer[buffer_key]) - 1, \
"Same module is being invoked more than once inside an `fp8_autocast` " \
"region when using FP8 with amax reduction. This behavior is currently" \
" unsupported. For more details and correct usage, please see " \
"https://github.com/NVIDIA/TransformerEngine/pull/93."
@classmethod @classmethod
def copy_amax_from_global_buffer( def add_fp8_tensors_to_global_buffer(
cls, fp8_meta: Dict[str, Any], forward: bool = True cls,
fp8_meta: Dict[str, Any],
fp8_weights: Optional[List[torch.Tensor]] = None,
) -> None: ) -> None:
"""Populate current amax with the correct location from buffer.""" """
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) The amax reduction process happens completely outside the FP8 modules.
buffer_position_key = cls.get_buffer_position_key(forward=forward) To participate in the reduction, the only role played by a module is
if buffer_position_key not in fp8_meta: to call this function in order to append it's FP8 tensor into a global
return buffer. There are 5 global buffers maintained, one each for amax, amax
history, scale, scale-inverse, and non-weight-mask. Each buffer has
amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix
assert amax_buffer_key in cls.global_fp8_buffer, "TE internal error." to indicate the type of FP8 tensor, since the forward and backward
reductions happen separately.
fp8_meta[fp8_meta_tensor_key].amax_history[0] = cls.global_fp8_buffer[amax_buffer_key][
fp8_meta[buffer_position_key] Note: For CG capture, this method is called from the graphed
] wrapper. For non CG case, it's called from within the module.
"""
@classmethod # Every module must call this function exactly once since
def set_amax_buffer_key_deletion( # the amax tensors are static. Ensures that compatibility
cls, fp8_meta: Dict[str, Any], forward: bool = True # with non-graphed modules is maintained.
) -> None: index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors.
"""Delete this amax key from global buffer during autocast end.""" if index_in_buffer in fp8_meta:
if cls.get_autocast_key(forward=forward) not in fp8_meta:
return return
if forward:
cls.buffer_delete_key_fwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
else:
cls.buffer_delete_key_bwd = cls.get_amax_buffer_key(fp8_meta, forward=forward)
@classmethod fp8_meta[index_in_buffer] = []
def delete_key_from_amax_buffer(cls, forward: bool = True) -> None: for forward in (True, False):
"""Delete the key from global amax buffer.""" # This algorithm creates a two-way map with `autocast_to_fp8_params` and
if forward: # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
if ( # in an autocasted region and cross reference them in `float8_tensor.py`
cls.buffer_delete_key_fwd is not None # to perform the forward amax reduction.
and cls.buffer_delete_key_fwd in cls.global_fp8_buffer if forward and fp8_weights is not None:
): autocast_key = cls.get_unique_autocast_key(
del cls.global_fp8_buffer[cls.buffer_delete_key_fwd] fp8_meta["recipe"], fp8_meta["fp8_group"])
fp8_weight_set = {id(w._data) for w in fp8_weights}
if autocast_key not in cls.autocast_to_fp8_params:
cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set
else: else:
if ( cls.autocast_to_fp8_params[autocast_key] = (
cls.buffer_delete_key_bwd is not None cls.autocast_to_fp8_params[autocast_key].union(fp8_weight_set))
and cls.buffer_delete_key_bwd in cls.global_fp8_buffer # Identify correct autocast key for a given param.
): for w in fp8_weight_set:
del cls.global_fp8_buffer[cls.buffer_delete_key_bwd] cls.fp8_param_to_autocast[w] = autocast_key
@classmethod key = cls.get_key_in_buffer(
def get_fp8_context_id(cls) -> int: forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"])
"""Returns an ID for the current FP8 context.""" fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
return cls.FP8_CURRENT_CONTEXT_ID
@classmethod
def set_fp8_context_id(cls, ctx_id: int) -> None:
"""Sets the current FP8 context."""
cls.FP8_CURRENT_CONTEXT_ID = ctx_id
@classmethod if key not in cls.global_amax_buffer:
def new_fp8_context_id(cls) -> int: cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
"""Returns global autocast counter as a proxy to be used cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history]
as the autocast ID for FP8 modules. cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale]
""" cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv]
return cls.FP8_AUTOCAST_COUNTER else:
cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
cls.global_amax_history_buffer[key].append(
fp8_meta[fp8_meta_tensor_key].amax_history)
cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale)
cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv)
fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
fp8_meta[index_in_buffer].append(key)
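A minimal standalone sketch of the two-way map described in the comments above, using plain dicts and integer ids in place of TE's class-level state (the key string and weight ids are illustrative assumptions, not taken from this change):

# Sketch only: forward map (autocast key -> set of FP8 weight ids) plus the
# reverse map used to look up the reduction bucket for a given weight.
autocast_to_fp8_params = {}
fp8_param_to_autocast = {}

def register_fp8_weights(autocast_key, weight_ids):
    bucket = autocast_to_fp8_params.setdefault(autocast_key, set())
    bucket.update(weight_ids)
    for wid in weight_ids:
        fp8_param_to_autocast[wid] = autocast_key

register_fp8_weights("recipe_a:group_0", {101, 102})
register_fp8_weights("recipe_a:group_0", {103})
assert fp8_param_to_autocast[103] == "recipe_a:group_0"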
@classmethod @classmethod
def is_fp8_enabled(cls) -> bool: def is_fp8_enabled(cls) -> bool:
...@@ -283,6 +252,11 @@ class FP8GlobalStateManager: ...@@ -283,6 +252,11 @@ class FP8GlobalStateManager:
"""Should the parameters be stored as FP8""" """Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS return cls.FP8_PARAMETERS
@classmethod
def fp8_graph_capturing(cls) -> bool:
"""Is CUDA graph capture under way?"""
return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()
@classmethod @classmethod
def is_first_fp8_module(cls): def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple """Returns `True` only the first time when called multiple
...@@ -310,7 +284,8 @@ class FP8GlobalStateManager: ...@@ -310,7 +284,8 @@ class FP8GlobalStateManager:
cls.FP8_CALIBRATION, cls.FP8_CALIBRATION,
cls.FP8_RECIPE, cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP, cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE) cls.IS_FIRST_FP8_MODULE,
cls.FP8_GRAPH_CAPTURING)
@classmethod @classmethod
def set_fp8_autocast_state( def set_fp8_autocast_state(
...@@ -322,80 +297,100 @@ class FP8GlobalStateManager: ...@@ -322,80 +297,100 @@ class FP8GlobalStateManager:
cls.FP8_CALIBRATION, cls.FP8_CALIBRATION,
cls.FP8_RECIPE, cls.FP8_RECIPE,
cls.FP8_DISTRIBUTED_GROUP, cls.FP8_DISTRIBUTED_GROUP,
cls.IS_FIRST_FP8_MODULE) = fp8_state cls.IS_FIRST_FP8_MODULE,
cls.FP8_GRAPH_CAPTURING) = fp8_state
@staticmethod @staticmethod
def reduce_tensor_across_group_op_max( def reduce_tensor_across_group_op_max(
tensor: torch.Tensor, group: dist_group_type, async_op: bool tensor: torch.Tensor, group: dist_group_type
) -> None: ) -> None:
"""Reduce tensor across given group.""" """Reduce tensor across given group."""
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
wait_handle = torch.distributed.all_reduce( torch.distributed.all_reduce(
tensor, tensor,
op=torch.distributed.ReduceOp.MAX, op=torch.distributed.ReduceOp.MAX,
group=group, group=group,
async_op=async_op, async_op=False,
) )
return wait_handle
return None
@classmethod @classmethod
def global_amax_reduction( def reduce_and_update_fp8_tensors(
cls, cls,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
forward: bool = True, forward: bool = True,
fp8_weights: bool = False,
) -> None: ) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer.""" """Concatenate, reduce, and split amaxes in the global buffer."""
amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) for buffer_key, amax_buffer in cls.global_amax_buffer.items():
# Check for forward or backward reduction.
# Key already deleted. fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key)
if amax_buffer_key not in cls.global_fp8_buffer: if fwd_update != forward:
return None continue
# Only skip a forward update when `fp8_weights` is explicitly set to `True`
# Reduce AMAX in DP-domain at an interval. # (inside optimizer) and the current key is not an `fp8_weight_update` key.
# `NVTE_DP_AMAX_REDUCE_INTERVAL` should be set as an integer value larger than 0. If # For other cases, we need to reduce because of activation tensors.
# `NVTE_DP_AMAX_REDUCE_INTERVAL` is set to 0, AMAX is reduced only in TP domain. # TODO(ksivaman) consider separate weight and activation fp8_tensors.
if cls.dp_amax_reduce_interval is None: if fwd_update and fp8_weights and not fp8_weights_update:
cls.dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) continue
if len(amax_buffer) == 0:
if cls.dp_amax_reduce_interval == 0: continue
tp_amax_reduce = True
else: # Retrieve autocast specific args and concat amaxes.
tp_amax_reduce = False recipe, group = cls.autocast_arguments[autocast_key]
if forward: contiguous_amax = torch.cat(amax_buffer)
if cls.dp_amax_reduce_forward_idx == 0:
reduce_group = fp8_meta["fp8_group"] # Reduction.
else: if (recipe.reduce_amax
tp_amax_reduce = True and torch.distributed.is_initialized()
cls.dp_amax_reduce_forward_idx = ( and torch.distributed.get_world_size(group=group) > 1):
(cls.dp_amax_reduce_forward_idx + 1) % cls.dp_amax_reduce_interval) cls.reduce_tensor_across_group_op_max(contiguous_amax, group)
else:
if cls.dp_amax_reduce_backward_idx == 0: # Amax and scale update.
reduce_group = fp8_meta["fp8_group"] unfused_update = (bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
or callable(recipe.amax_compute_algo)
or callable(recipe.scaling_factor_compute_algo))
if not unfused_update:
tex.fused_amax_and_scale_update_after_reduction(
contiguous_amax,
cls.global_amax_history_buffer[buffer_key],
cls.global_scale_buffer[buffer_key],
cls.global_scale_inv_buffer[buffer_key],
recipe.amax_compute_algo,
get_fp8_te_dtype(recipe, forward),
recipe.margin,
)
else: else:
tp_amax_reduce = True split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])
cls.dp_amax_reduce_backward_idx = (
(cls.dp_amax_reduce_backward_idx + 1) % cls.dp_amax_reduce_interval)
if tp_amax_reduce: for amax_history, scale, scale_inv in zip(
if tp_size > 1: cls.global_amax_history_buffer[buffer_key],
reduce_group = tp_group cls.global_scale_buffer[buffer_key],
else: cls.global_scale_inv_buffer[buffer_key],
return None ):
_amax_and_scale_update(
amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe)
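The loop above follows a concatenate, all-reduce(MAX), write-back pattern. A single-process sketch of that pattern (the distributed call is stubbed out and the values are made up):

import torch

amax_buffer = [torch.tensor([1.0]), torch.tensor([3.0]), torch.tensor([0.5])]
contiguous_amax = torch.cat(amax_buffer)   # one flat tensor, one collective
# In the multi-GPU case the MAX all-reduce over `contiguous_amax` would run here,
# e.g. torch.distributed.all_reduce(..., op=ReduceOp.MAX, group=...).
chunks = contiguous_amax.split([t.numel() for t in amax_buffer])
for dst, src in zip(amax_buffer, chunks):
    dst.copy_(src)                         # in-place write-back keeps storage addresses stable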
chunk_sizes = [x.numel() for x in cls.global_fp8_buffer[amax_buffer_key]] @classmethod
contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key]) def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor):
"""Add tensor to list for multi grad hook."""
cls.multi_grad_hook_tensors.append(tensor)
wait_handle = cls.reduce_tensor_across_group_op_max( @classmethod
contiguous_amax, def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument
reduce_group, """Executes at the end of backward pass."""
fp8_meta["async_amax_reduction"], cls.reduce_and_update_fp8_tensors(forward=False)
)
cls.global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) @classmethod
return wait_handle def get_unique_autocast_key(
cls,
recipe: Optional[DelayedScaling] = None,
group: Optional[dist_group_type] = None,
):
"""
For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
Safely using `hash` as we never cross checkpoint boundaries.
"""
return f"{str(recipe)}:{hash(group)}"
@classmethod @classmethod
def fp8_autocast_enter( def fp8_autocast_enter(
...@@ -404,21 +399,29 @@ class FP8GlobalStateManager: ...@@ -404,21 +399,29 @@ class FP8GlobalStateManager:
calibrating: bool = False, calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None, fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None, fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None: ) -> None:
"""Set state and tracking variables for entry into FP8 region.""" """Set state and tracking variables for entry into FP8 region."""
if cls.FP8_AUTOCAST_DEPTH == 0:
if callable(cls.amax_forward_global_reduce_func): fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
cls.delete_key_from_amax_buffer(forward=True) cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0:
# This hook does not fire for graphed modules.
torch.autograd.graph.register_multi_grad_hook(
tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction)
cls.bwd_amax_update_hook_registered = True
cls.FP8_ENABLED = enabled cls.FP8_ENABLED = enabled
cls.FP8_CALIBRATION = calibrating cls.FP8_CALIBRATION = calibrating
cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe cls.FP8_RECIPE = fp8_recipe
cls.FP8_DISTRIBUTED_GROUP = fp8_group cls.FP8_DISTRIBUTED_GROUP = fp8_group
cls.FP8_GRAPH_CAPTURING = _graph
if cls.FP8_AUTOCAST_DEPTH == 0: if cls.FP8_AUTOCAST_DEPTH == 0:
cls.IS_FIRST_FP8_MODULE = True cls.IS_FIRST_FP8_MODULE = True
cls.FP8_AUTOCAST_COUNTER += 1
cls.FP8_AUTOCAST_DEPTH += 1 cls.FP8_AUTOCAST_DEPTH += 1
if enabled: if enabled:
...@@ -426,9 +429,14 @@ class FP8GlobalStateManager: ...@@ -426,9 +429,14 @@ class FP8GlobalStateManager:
assert fp8_available, reason_for_no_fp8 assert fp8_available, reason_for_no_fp8
@classmethod @classmethod
def fp8_autocast_exit(cls): def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
"""Set state and tracking variables for exit from FP8 region.""" """Set state and tracking variables for exit from FP8 region."""
cls.FP8_AUTOCAST_DEPTH -= 1 cls.FP8_AUTOCAST_DEPTH -= 1
# Reduce only the non-FP8 weight modules here.
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False)
@classmethod @classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
...@@ -525,6 +533,7 @@ def fp8_autocast( ...@@ -525,6 +533,7 @@ def fp8_autocast(
calibrating: bool = False, calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None, fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None, fp8_group: Optional[dist_group_type] = None,
_graph: bool = False,
) -> None: ) -> None:
""" """
Context manager for FP8 usage. Context manager for FP8 usage.
...@@ -568,23 +577,25 @@ def fp8_autocast( ...@@ -568,23 +577,25 @@ def fp8_autocast(
FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
calibrating=calibrating, calibrating=calibrating,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
fp8_group=fp8_group) fp8_group=fp8_group,
_graph=_graph)
yield yield
finally: finally:
FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment
FP8GlobalStateManager.fp8_autocast_exit() FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
"""Update amax history and set next amax to zero.""" """Update amax history and set next amax to zero."""
if amax_history.shape[0] > 1: if amax_history.shape[0] > 1:
amax_history = torch.roll(amax_history, -1, 0) new_amax_history = torch.roll(amax_history, -1, 0)
amax_history.copy_(new_amax_history)
amax_history[0].fill_(0.0) amax_history[0].fill_(0.0)
return amax_history return amax_history
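A tiny numeric illustration of the rolling history above (shapes and values are arbitrary):

import torch

amax_history = torch.tensor([[2.0], [1.0], [4.0]])  # [history_len, num_fp8_tensors]
rolled = torch.roll(amax_history, -1, 0)             # row 0 cycles to the end
amax_history.copy_(rolled)
amax_history[0].fill_(0.0)                           # fresh slot for the next iteration's amax
# amax_history is now [[0.0], [4.0], [2.0]]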
@torch.jit.script @torch.jit.script
def _default_get_amax( def _default_get_amax_and_update_history(
amax_history: torch.Tensor, amax_history: torch.Tensor,
amax_compute_algo: str, amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -609,63 +620,23 @@ def _default_sf_compute( ...@@ -609,63 +620,23 @@ def _default_sf_compute(
sf = (fp8_max / amax) / (2 ** margin) sf = (fp8_max / amax) / (2 ** margin)
sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale)
return sf scale.copy_(sf)
return scale
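For reference, a plain-Python sketch of the delayed-scaling rule above; `fp8_max`, `margin`, and the sample values are illustrative assumptions:

import torch

def sketch_scaling_factor(amax, scale, fp8_max=448.0, margin=0):
    # sf = (fp8_max / amax) / 2**margin, keeping the old scale when amax is
    # zero or non-finite.
    sf = (fp8_max / amax) / (2 ** margin)
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    return sf

print(sketch_scaling_factor(torch.tensor([2.0, 0.0]), torch.ones(2)))  # tensor([224., 1.])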
@jit_fuser
def _compute_scaling_factor_inverse(
scale: torch.Tensor,
scale_inv: torch.Tensor,
non_weight_mask: torch.Tensor,
update_weight_scale_inv: bool,
) -> torch.Tensor:
"""Compute inverse of scaling factor."""
if update_weight_scale_inv:
return 1.0 / scale
return torch.where(non_weight_mask, 1.0 / scale, scale_inv)
def _fused_amax_and_scale_update( def _compute_amax_and_update_history(
amax_history: torch.Tensor, amax_history: torch.Tensor,
scale: torch.Tensor, amax_compute_algo: Union[Callable, str],
scale_inv: torch.Tensor,
fp8_dtype: tex.DType,
margin: int,
amax_compute_algo: str,
non_weight_mask: torch.Tensor,
update_weight_scale_inv: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Update amax history and FP8 scaling factors"""
if update_weight_scale_inv:
non_weight_mask = torch.Tensor()
tex.fused_amax_and_scale_update(
amax_history,
scale,
scale_inv,
non_weight_mask,
amax_history,
scale,
scale_inv,
amax_compute_algo,
fp8_dtype,
margin,
)
return amax_history, scale, scale_inv
def _compute_amax(
amax_history: torch.Tensor,
recipe: DelayedScaling,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Obtain the amax from the history.""" """Obtain the amax from the history."""
if callable(recipe.amax_compute_algo): if callable(amax_compute_algo):
amax = recipe.amax_compute_algo(amax_history) amax = amax_compute_algo(amax_history)
amax_history = _update_amax_history(amax_history) amax_history = _update_amax_history(amax_history)
return amax_history, amax return amax_history, amax
return _default_get_amax( return _default_get_amax_and_update_history(
amax_history, amax_history,
recipe.amax_compute_algo, amax_compute_algo,
) )
...@@ -687,46 +658,29 @@ def _compute_scaling_factor( ...@@ -687,46 +658,29 @@ def _compute_scaling_factor(
return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)
def amax_and_scale_update( def _amax_and_scale_update(
fp8_meta: Dict[str, Any], amax_history: torch.Tensor,
fwd_update: bool, scale: torch.Tensor,
update_weight_scale_inv: bool = True, scale_inv: torch.Tensor,
fp8_max: float,
recipe: DelayedScaling,
) -> None: ) -> None:
"""Updates fp8 amaxes/scales for fwd | bwd.""" """Updates FP8 meta tensors."""
amax_compute = fp8_meta["recipe"].amax_compute_algo new_amax_history, amax = _compute_amax_and_update_history(
sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo amax_history,
fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd" recipe.amax_compute_algo,
fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
if not callable(amax_compute) and sf_compute is None:
(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
) = _fused_amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
get_fp8_te_dtype(fp8_meta["recipe"], fwd_update),
fp8_meta["recipe"].margin,
fp8_meta["recipe"].amax_compute_algo,
fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
update_weight_scale_inv,
)
else:
fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta["recipe"],
)
fp8_meta[fp8_meta_tensor_key].scale = _compute_scaling_factor(
amax,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_max_key],
fp8_meta["recipe"],
)
fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse(
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
update_weight_scale_inv,
) )
new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe)
scale.copy_(new_scale)
scale_inv.copy_(1.0 / new_scale)
amax_history.copy_(new_amax_history)
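The in-place `copy_` calls above matter for CUDA graph replay, which reads and writes fixed addresses. A minimal sketch of why an out-of-place update would be invisible to a captured graph:

import torch

scale = torch.ones(3)
ptr_before = scale.data_ptr()
new_scale = scale * 2.0      # out-of-place: lives in different storage
scale.copy_(new_scale)       # in-place: same storage, so a captured graph sees the update
assert scale.data_ptr() == ptr_before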
def split_and_copy(
buffer: torch.Tensor,
outputs: List[torch.Tensor],
chunk_sizes: List[int],
) -> None:
"""Split `buffer` by `chunk_sizes` and copy into `outputs`."""
splits = buffer.split(chunk_sizes)
torch._foreach_copy_(outputs, splits)
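A small usage example of the helper above, assuming it is in scope (shapes are arbitrary):

import torch

outputs = [torch.zeros(2), torch.zeros(3)]
buffer = torch.arange(5, dtype=torch.float32)
split_and_copy(buffer, outputs, [2, 3])
# outputs == [tensor([0., 1.]), tensor([2., 3., 4.])]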
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Functions for CUDA Graphs support in FP8"""
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten
from torch.utils._pytree import tree_unflatten as _tree_unflatten
from torch._C import _graph_pool_handle
from .fp8 import (
fp8_autocast,
FP8GlobalStateManager,
get_default_fp8_recipe,
)
from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule
__all__ = ["make_graphed_callables"]
_IS_GRAPH_CAPTURING = False
def set_capture_start() -> None:
"""Record beginning of `make_graphed_callables`."""
global _IS_GRAPH_CAPTURING
_IS_GRAPH_CAPTURING = True
def set_capture_end() -> None:
"""Record end of `make_graphed_callables`."""
global _IS_GRAPH_CAPTURING
_IS_GRAPH_CAPTURING = False
def is_graph_capturing() -> bool:
"""Return whether within `make_graphed_callables`."""
return _IS_GRAPH_CAPTURING
def graph_pool_handle():
"""
Returns an opaque token representing the id of a graph memory pool.
"""
return _graph_pool_handle()
def _make_graphed_callables(
callables,
sample_args,
num_warmup_iters=3,
allow_unused_input=False,
fp8_weight_caching=False,
_order=None,
):
"""
Helper method for `make_graphed_callables`
"""
if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
raise RuntimeError(
"make_graphed_callables does not support the autocast "
"caching. Please set `cache_enabled=False`."
)
just_one_callable = False
if not isinstance(callables, tuple):
just_one_callable = True
callables = (callables,)
sample_args = (sample_args,)
flatten_sample_args = []
if _order is not None:
# `_order` lists model chunk ids in microbatch-schedule order: a positive id i
# schedules a forward capture for chunk i, and -i schedules the matching backward capture.
num_model_chunks = max(_order)
num_microbatches = len(_order) // num_model_chunks // 2
assert num_model_chunks * num_microbatches * 2 == len(_order)
assert (
len(sample_args)*2 >= len(_order)
and (len(sample_args)*2 % len(_order) == 0)
), f'{len(sample_args)*2} >= {len(_order)} and {len(sample_args)*2} % {len(_order)} == 0'
num_layers = len(sample_args) // num_model_chunks // num_microbatches
assert (
len(callables) == num_model_chunks*num_layers
), (f"Callables should have ({num_model_chunks * num_layers}) "
+ f"entries when order input is provided but got {len(callables)}."
)
assert (
len(sample_args) == num_model_chunks * num_microbatches * num_layers
), (f"Expected {num_model_chunks * num_microbatches}"
+ f"args tuple, but got {len(sample_args)}."
)
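As an illustration of the `_order` convention checked above (the concrete schedule below is an assumption, not taken from this change):

# Example: 2 model chunks, 2 microbatches -> len(_order) == 2 * 2 * 2 == 8.
# Positive ids request a forward capture for that chunk, negative ids the backward capture.
_order_example = [1, 1, 2, 2, -2, -2, -1, -1]
num_model_chunks = max(_order_example)                             # 2
num_microbatches = len(_order_example) // num_model_chunks // 2    # 2
assert num_model_chunks * num_microbatches * 2 == len(_order_example)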
if fp8_weight_caching:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
for c in callables:
if isinstance(c, torch.nn.Module):
assert (
len(c._backward_hooks) == 0
and len(c._forward_hooks) == 0
and len(c._forward_pre_hooks) == 0
), (
"Modules must not have hooks registered at the time they are passed. "
+ "However, registering hooks on modules after passing them "
+ "through make_graphed_callables is allowed."
)
assert all(b.requires_grad is False for b in c.buffers()), (
"In any :class:`~torch.nn.Module` passed to "
+ ":func:`~make_graphed_callables`, only parameters may be trainable. "
+ "All buffers must have ``requires_grad=False``."
)
for args in sample_args:
flatten_arg, _ = _tree_flatten(args)
flatten_sample_args.append(tuple(flatten_arg))
assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
"In the beta API, sample_args "
+ "for each callable must contain only Tensors. Other types are not allowed."
)
# If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
# passes to forward (ie, its sample_args) AND the module's parameter attributes.
per_callable_len_user_args = [len(args) for args in flatten_sample_args]
if _order is None:
per_callable_module_params = [
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
for c in callables
]
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
for i in range(len(callables))
]
else:
per_callable_module_params = []
for c in callables:
for i in range(num_microbatches):
per_callable_module_params.append(
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
)
assert len(per_callable_module_params) == len(flatten_sample_args)
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
for i in range(len(flatten_sample_args))
]
fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
graph_callables = [None for _ in range(len(flatten_sample_args))]
# For cases with multiple active RNG states, e.g. TP.
if graph_safe_rng_available():
for _, state in get_all_rng_states().items():
for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
fwd_graph.register_generator_state(state)
bwd_graph.register_generator_state(state)
mempool = graph_pool_handle()
# Warmup
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):
for c_i, func in enumerate(callables):
args = sample_args[c_i]
static_input_surface = per_callable_static_input_surfaces[c_i]
for _ in range(num_warmup_iters):
outputs, _ = _tree_flatten(func(*args))
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
only_inputs=True,
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
torch.cuda.synchronize()
# All captures here share a mempool. To avoid replays corrupting each other's memory,
# the safest approach is to capture all passes in the same order they'll run:
# fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
if _order is not None: # pylint: disable=too-many-nested-blocks
per_callable_static_outputs = [None] * len(flatten_sample_args)
per_callable_output_unflatten_spec = [None] * len(flatten_sample_args)
per_callable_static_grad_outputs = [None] * len(flatten_sample_args)
per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
m_chunk = c_id-1
for l_no in range(num_layers):
func = callables[m_chunk*num_layers + l_no]
per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) \
+ (fwd_idx[m_chunk] * num_layers + l_no)
args = sample_args[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
graph_callables[per_callable_fwd_idx] = func
fwd_idx[m_chunk] += 1
else:
# Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
m_chunk = -c_id-1
for l_no in list(reversed(range(num_layers))):
per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) \
+ (bwd_idx[m_chunk] * num_layers + l_no)
static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None) # type: ignore[arg-type]
static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
bwd_idx[m_chunk] += 1
else:
# Capture forward graphs
per_callable_static_outputs = []
per_callable_output_unflatten_spec = []
graph_id = 0
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
graph_callables[graph_id] = func
graph_id += 1
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs.append(tuple(flatten_outputs))
per_callable_output_unflatten_spec.append(spec)
# Capture backward graphs in reverse order
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
for static_input_surface, static_outputs, bwd_graph in zip(
reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs),
reversed(bwd_graphs),
):
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs that
# don't require grad. I couldn't think of a slick one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None) # type: ignore[arg-type]
static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
per_callable_static_grad_outputs.append(static_grad_outputs)
per_callable_static_grad_inputs.append(static_grad_inputs)
# Reverses the most recent two lists
per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
# Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
def make_graphed_autograd_function(
fwd_graph,
bwd_graph,
module_params,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
static_grad_outputs,
static_grad_inputs,
):
class Graphed(torch.autograd.Function):
"""Autograd function for graph replay."""
@staticmethod
def forward(ctx, skip_fp8_weight_update, *inputs):
# At this stage, only the user args may (potentially) be new tensors.
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
if ctx.is_first_module and skip_fp8_weight_update is not None:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
static_input_surface[i].copy_(inputs[i])
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *grads):
assert len(grads) == len(static_grad_outputs)
for g, grad in zip(static_grad_outputs, grads):
if g is not None:
# don't copy if autograd gods have been kind and the
# incoming grad is already in the right place
if g.data_ptr() != grad.data_ptr():
g.copy_(grad)
bwd_graph.replay()
if ctx.is_first_module:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
return (None,) + tuple(
b.detach() if b is not None else b for b in static_grad_inputs
)
def functionalized(*user_args, **user_kwargs):
# Runs the autograd function with inputs == all
# inputs to the graph that might require grad
# (explicit user args + module parameters)
# Assumes module params didn't change since capture.
skip_fp8_weight_update = None
if fp8_weight_caching:
assert (
("is_first_microbatch" in user_kwargs
and isinstance(user_kwargs["is_first_microbatch"], bool))
), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching."
skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]
flatten_user_args, _ = _tree_flatten(user_args)
out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params))
return _tree_unflatten(out, output_unflatten_spec)
return functionalized
# Put together the final graphed callables
ret = []
for i in range(len(sample_args)):
graphed = make_graphed_autograd_function(
fwd_graphs[i],
bwd_graphs[i],
per_callable_module_params[i],
per_callable_len_user_args[i],
per_callable_output_unflatten_spec[i],
per_callable_static_input_surfaces[i],
per_callable_static_outputs[i],
per_callable_static_grad_outputs[i],
per_callable_static_grad_inputs[i],
)
func = graph_callables[i]
if isinstance(func, torch.nn.Module):
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
def new_fwd(*user_args, **user_kwargs):
# If the module's training-or-eval state matches what we graphed,
# run the graph, otherwise run the original forward method
if func.training == graph_training_state:
# Set the FP8 group from global amax reduction.
for m in func.modules():
if (isinstance(m, TransformerEngineBaseModule)
and FP8GlobalStateManager.is_fp8_enabled()):
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m.fp8_meta, fp8_weights=m._get_fp8_params())
return graphed(*user_args, **user_kwargs)
return orig_fwd(*user_args, **user_kwargs)
return new_fwd
forward = make_graphed_forward(func, func.training, graphed, func.forward)
if _order is None:
func.forward = forward
ret.append(func)
else:
ret.append(forward)
else:
ret.append(graphed)
if just_one_callable:
return ret[0]
return tuple(ret)
def save_fp8_tensors(modules, amax_history_len):
"""
Returns the FP8 tensors for all modules
with adjusted amax history sizes.
"""
saved_fp8_meta_tensors = []
for module in modules:
for m in module.modules():
if isinstance(m, TransformerEngineBaseModule):
if m.primary_weights_in_fp8:
m.adjust_amax_history_length(amax_history_len)
saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors())
return saved_fp8_meta_tensors
def restore_fp8_tensors(modules, fp8_tensors):
"""Restore FP8 tensors."""
for module in modules:
for m in module.modules():
if isinstance(m, TransformerEngineBaseModule):
m.reset_fp8_meta_tensors(fp8_tensors.pop(0))
assert len(fp8_tensors) == 0, "TE internal error."
def make_graphed_callables(
modules,
sample_args,
num_warmup_iters=3,
allow_unused_input=False,
fp8_enabled=False,
fp8_calibrating=False,
fp8_recipe=None,
fp8_weight_caching=False,
_order=None,
):
"""
A version of PyTorch's `make_graphed_callables` utility function with support for
TransformerEngine modules and FP8. Please see the original version in upstream PyTorch
`here <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
for extensive documentation. Documentation for the additional parameters, which are
specific to FP8, is given below.
FP8 specific parameters
-----------------------
fp8_enabled: bool, default = `False`
whether or not to enable fp8
fp8_calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
data of fp8 tensors even when executing without fp8 enabled. This is
useful for saving an inference ready fp8 checkpoint while training
using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
fp8_weight_caching: bool, default = `False`
Whether or not to cache FP8 weights across microbatches. If set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward
method for TransformerEngine modules. When storing primary weights in FP8
using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg
must be set to `False` when calculating weight transposes outside TE, e.g.,
in the optimizer step.
"""
set_capture_start()
fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
# Handle single module.
just_one_callable = False
if not isinstance(modules, tuple):
just_one_callable = True
modules = (modules,)
# Store FP8 tensors to reset later.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len)
# FP8 wrapper.
def wrap_autocast(block):
old_forward = block.forward
def forward_func(*args, **kwargs):
with fp8_autocast(enabled=fp8_enabled,
calibrating=fp8_calibrating,
fp8_recipe=fp8_recipe,
_graph=True):
outputs = old_forward(*args, **kwargs)
return outputs
block.forward = forward_func
forward_funcs = []
for module in modules:
assert isinstance(module, torch.nn.Module), f"Graphing for {type(module)} is not supported."
wrap_autocast(module)
forward_funcs.append(module)
if just_one_callable:
forward_funcs = forward_funcs[0]
else:
forward_funcs = tuple(forward_funcs)
# Save RNG state.
if graph_safe_rng_available():
generators = [torch.cuda.default_generators[torch.cuda.current_device()],
*get_all_rng_states().values()]
original_rng_states = [state.get_state() for state in generators]
else:
original_rng_states = torch.cuda.get_rng_state()
graphed_callables = _make_graphed_callables(
forward_funcs, sample_args, num_warmup_iters=num_warmup_iters,
allow_unused_input=allow_unused_input,
fp8_weight_caching=fp8_weight_caching, _order=_order)
# Ensures warmup does not affect numerics for ops such as dropout.
if graph_safe_rng_available():
for gen, state in zip(generators, original_rng_states):
gen.set_state(state)
else:
torch.cuda.set_rng_state(original_rng_states)
# Reset FP8 gradients.
for module in modules:
for p in module.parameters():
p.grad = None
# Restore FP8 state.
restore_fp8_tensors(modules, saved_fp8_tensors)
set_capture_end()
return graphed_callables
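A minimal usage sketch of the API above; the module type, shapes, and recipe values are placeholders, FP8 execution assumes FP8-capable hardware, and it is assumed that `make_graphed_callables` is exported from `transformer_engine.pytorch`:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(8, 1024, device="cuda", requires_grad=True)

# Returns the module with its forward patched to replay captured CUDA graphs.
graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    num_warmup_iters=3,
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(amax_history_len=16),
)

out = graphed_model(torch.randn(8, 1024, device="cuda", requires_grad=True))
out.sum().backward()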
...@@ -8,8 +8,7 @@ import os ...@@ -8,8 +8,7 @@ import os
import pickle import pickle
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generator, Union, Optional, Tuple, Dict, Any, List from typing import Generator, Union, Optional, Tuple, List
from functools import partial
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
...@@ -22,13 +21,11 @@ from ..fp8 import ( ...@@ -22,13 +21,11 @@ from ..fp8 import (
get_default_fp8_recipe, get_default_fp8_recipe,
get_fp8_te_dtype, get_fp8_te_dtype,
FP8GlobalStateManager, FP8GlobalStateManager,
amax_and_scale_update,
) )
from ..distributed import ( from ..distributed import (
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled, is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase, in_fp8_activation_recompute_phase,
get_distributed_world_size,
) )
from ..cpp_extensions import ( from ..cpp_extensions import (
fp8_cast_transpose_fused, fp8_cast_transpose_fused,
...@@ -44,7 +41,6 @@ _2X_ACC_WGRAD = True ...@@ -44,7 +41,6 @@ _2X_ACC_WGRAD = True
_cublas_workspace = None _cublas_workspace = None
_ub_communicators = None _ub_communicators = None
_NUM_MAX_UB_STREAMS = 3 _NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None
layers_atomic_ring_exchange = [] layers_atomic_ring_exchange = []
...@@ -64,49 +60,6 @@ def get_workspace() -> torch.Tensor: ...@@ -64,49 +60,6 @@ def get_workspace() -> torch.Tensor:
) )
return _cublas_workspace return _cublas_workspace
@contextmanager
def _prepare_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
name: str = ""
) -> Generator[None, None, None]:
"""Checks and prep for BWD."""
if fp8:
global _amax_reduce_handle_bwd
if _amax_reduce_handle_bwd is not None:
_amax_reduce_handle_bwd.wait()
_amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction
if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
# From previous iteration
FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
else:
amax_and_scale_update(fp8_meta, False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if (fp8 and fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(fp8_meta["fp8_group"]) > 1):
if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
fp8_meta,
tp_group,
tp_size,
forward=False
)
FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False)
def initialize_ub( def initialize_ub(
shape: list, shape: list,
...@@ -300,31 +253,54 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -300,31 +253,54 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_size = 1 self.tp_size = 1
self.sequence_parallel = False self.sequence_parallel = False
self.fp8_weight_shapes = [] self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
)
self.param_init_meta = {} self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`.
.. warning::
This changes the underlying amax memory location.
"""
if fwd is None:
fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
else:
fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)
for meta_key in fp8_meta_tensor_keys:
curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
if length == curr_len:
continue
if length < curr_len:
self.fp8_meta[meta_key].amax_history = (
self.fp8_meta[meta_key].amax_history[: length].clone())
elif length > curr_len:
extra_rows = length - curr_len
self.fp8_meta[meta_key].amax_history = F.pad(
self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
)
# Update the global buffers with new amax and history pointers.
if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
fwd_pos, fwd_key, bwd_pos, bwd_key = (
self.fp8_meta[FP8GlobalStateManager.get_buffer_info()])
for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
if buffer_key in FP8GlobalStateManager.global_amax_buffer:
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
self.fp8_meta[meta_key].amax_history[0])
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
self.fp8_meta[meta_key].amax_history)
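A small sketch of the resize rule implemented above: shrink by keeping the leading rows, grow by zero-padding extra rows (values are arbitrary):

import torch
import torch.nn.functional as F

amax_history = torch.tensor([[1.0], [2.0]])      # history_len = 2
longer = F.pad(amax_history, pad=(0, 0, 0, 3))   # history_len = 5, new rows are zeros
shorter = amax_history[:1].clone()               # history_len = 1
print(longer.shape, shorter.shape)               # torch.Size([5, 1]) torch.Size([1, 1])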
def set_meta_tensor(self, fwd: bool) -> None: def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
if self.fp8_meta_tensors_initialized: if self.fp8_meta_tensors_initialized:
# Handle changed amax history size. # Handle changed amax history size.
curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0] self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd)
need_len = self.fp8_meta["recipe"].amax_history_len
if need_len < curr_len:
self.fp8_meta[fp8_meta_tensor_key].amax_history = (
self.fp8_meta[fp8_meta_tensor_key]
.amax_history[: self.fp8_meta["recipe"].amax_history_len].clone()
)
elif need_len > curr_len:
extra_rows = need_len - curr_len
self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad(
self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows)
)
return return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
...@@ -347,25 +323,45 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -347,25 +323,45 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
device="cuda", device="cuda",
) )
# Needed for calculation of scale inverses to
# preserve scale_inv when caching FP8 weights
if fwd:
# [True, False, True]: -> [input, weight, output]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, False, True] * self.fp8_meta["num_gemms"]
).cuda()
else:
# [True, True]: -> [grad_output, grad_input]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, True] * self.fp8_meta["num_gemms"]
).cuda()
def init_fp8_meta_tensors(self) -> None: def init_fp8_meta_tensors(self) -> None:
"""Init scales and amaxes.""" """Init scales and amaxes."""
self.set_meta_tensor(True) self.set_meta_tensor(True)
self.set_meta_tensor(False) self.set_meta_tensor(False)
self.fp8_meta_tensors_initialized = True self.fp8_meta_tensors_initialized = True
def get_fp8_meta_tensors(self) -> None:
"""Get scales and amaxes."""
fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
return None
fp8_meta_tensors = {fwd_key: [], bwd_key: []}
with torch.no_grad():
for key in (fwd_key, bwd_key):
fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone())
fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone())
fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone())
return fp8_meta_tensors
def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
"""Reset scales and amaxes."""
def reset(key):
if key in self.fp8_meta:
if fp8_meta_tensors is None:
self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale))
self.fp8_meta[key].scale_inv.copy_(
torch.ones_like(self.fp8_meta[key].scale_inv))
self.fp8_meta[key].amax_history.copy_(
torch.zeros_like(self.fp8_meta[key].amax_history))
else:
assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0])
self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1])
self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2])
with torch.no_grad():
reset("scaling_fwd")
reset("scaling_bwd")
def get_extra_state(self) -> torch.Tensor: def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing.""" """Save before checkpointing."""
state = None state = None
...@@ -380,13 +376,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -380,13 +376,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint()
# Store other picklable values. # Store other picklable values.
extra = {} extra = {}
for k, v in self.fp8_meta.items(): for k, v in self.fp8_meta.items():
if isinstance(v, (bool, int, float, str, list)): if isinstance(v, (bool, int, float, str, tuple, list)):
extra[k] = v extra[k] = v
state["extra_fp8_variables"] = extra state["extra_fp8_variables"] = extra
...@@ -414,11 +408,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -414,11 +408,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None: if state is None:
return return
# Restore global FP8 amax buffer.
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
# Restore global FP8 state.
FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
# Load extra items. # Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
...@@ -527,6 +516,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -527,6 +516,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_group = tp_group self.tp_group = tp_group
self.tp_group_initialized = True self.tp_group_initialized = True
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
"""returns the FP8 weights."""
fp8_params = []
for param in self.parameters():
if isinstance(param, Float8Tensor) and param.requires_grad:
fp8_params.append(param)
if len(fp8_params) == 0:
return None
return fp8_params
# This routine is shared across FP8 and FP8_calibration paths so should not actually # This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution. # assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None: def init_fp8_metadata(self, num_gemms: int = 1) -> None:
...@@ -576,7 +575,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -576,7 +575,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
to setup the forward aggregated amax reduction for every module to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one. just in case. The autocast exit will pick up the most recent one.
""" """
# Activation recomputation is used and this is the second forward phase. # Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
...@@ -594,49 +592,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -594,49 +592,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if is_first_microbatch is not None and not self.primary_weights_in_fp8: if is_first_microbatch is not None and not self.primary_weights_in_fp8:
self.set_fp8_weights() self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if self.fp8 and self.sequence_parallel: if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \ assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \ "Amax reduction across tensor parallel group is " \
"necessary when using sequence parallelism with FP8." "necessary when using sequence parallelism with FP8."
# Previous iteration was grad_enabled if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
if self.fp8_meta.get("update_amax_and_scale_fwd", False): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
if (self.fp8_meta["recipe"].reduce_amax self.fp8_meta, fp8_weights=self._get_fp8_params())
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.new_fp8_context_id())
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.get_fp8_context_id())
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
# Activation recomputation is used and this is the first forward phase. # Activation recomputation is used and this is the first forward phase.
if ( if (
...@@ -653,18 +616,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -653,18 +616,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return return
if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
FP8GlobalStateManager.global_amax_reduction,
self.fp8_meta,
self.tp_group,
self.tp_size,
forward=True
)
FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func)
def set_nccl_overlap_warning_if_tp(self) -> None: def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled """When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the before the GEMM for there to be a guaranteed overlap. From the
......
...@@ -14,7 +14,6 @@ from .. import cpp_extensions as tex ...@@ -14,7 +14,6 @@ from .. import cpp_extensions as tex
from .base import ( from .base import (
get_workspace, get_workspace,
_prepare_backward,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -65,6 +64,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -65,6 +64,7 @@ class _LayerNormLinear(torch.autograd.Function):
use_bias: bool, use_bias: bool,
eps: float, eps: float,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
...@@ -89,6 +89,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -89,6 +89,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_rs_dgrad: bool, ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
ub_name: str, ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -98,7 +99,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -98,7 +99,11 @@ class _LayerNormLinear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight) assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
# Cast for native AMP # Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype) inputmat = cast_if_needed(inputmat, activation_dtype)
...@@ -196,7 +201,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -196,7 +201,6 @@ class _LayerNormLinear(torch.autograd.Function):
# Weight is already in FP8 # Weight is already in FP8
weight.reset_fp8_meta_scale_inv() weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights: elif update_fp8_weights:
# Need to cast weights to FP8 # Need to cast weights to FP8
weight_fp8 = Float8Tensor( weight_fp8 = Float8Tensor(
...@@ -214,6 +218,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -214,6 +218,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
cast_out=weight_fp8._data, cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data, transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
) )
else: else:
tex.cast_to_fp8( tex.cast_to_fp8(
...@@ -295,6 +300,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -295,6 +300,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_t_fp8, weight_t_fp8,
ln_out if weight.requires_grad else None, ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
...@@ -321,6 +327,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -321,6 +327,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ub_name = ub_name ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and sequence_parallel: if parallel_mode == "row" and sequence_parallel:
...@@ -344,9 +351,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -344,9 +351,7 @@ class _LayerNormLinear(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward( with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
):
( (
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -357,6 +362,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -357,6 +362,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_t_fp8, weight_t_fp8,
ln_out, ln_out,
fwd_scale_inverses, fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
...@@ -364,10 +370,13 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -364,10 +370,13 @@ class _LayerNormLinear(torch.autograd.Function):
weight.main_grad = main_grad weight.main_grad = main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None: if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose( weight_t_fp8 = weight.transpose_2d(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
) )
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
if ctx.ub_overlap_rs_dgrad: if ctx.ub_overlap_rs_dgrad:
ctx.ub_bulk_dgrad = False ctx.ub_bulk_dgrad = False
...@@ -472,7 +481,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -472,7 +481,7 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm( _ = tex.fp8_gemm(
weight_t_fp8._data, weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -686,6 +695,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -686,6 +695,8 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
) )
...@@ -970,7 +981,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -970,7 +981,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata() self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta')) self.reset_parameters(defer_init=(device == 'meta'))
...@@ -990,6 +1000,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -990,6 +1000,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
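The dummy tensor registered here is only an anchor for a backward-pass hook: once its gradient is produced, a single amax reduction can be launched for all layers that registered one. A minimal sketch of that mechanism in plain PyTorch (an illustration built on torch.autograd.graph.register_multi_grad_hook, not the FP8GlobalStateManager implementation):

import torch

# A 1-element dummy tensor rides along in the autograd graph; the multi-grad
# hook fires once gradients for all registered tensors are available, which is
# a convenient single point to trigger a backward-pass reduction.
dummy = torch.zeros(1, requires_grad=True)

def on_grads_ready(grads):
    print("backward reached the dummy tensor:", [g.shape for g in grads])

torch.autograd.graph.register_multi_grad_hook([dummy], on_grads_ready)

x = torch.randn(4, requires_grad=True)
loss = (2 * x + dummy).sum()   # dummy is threaded through the forward pass
loss.backward()                # hook runs after all registered grads are ready
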
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
warnings.warn( warnings.warn(
...@@ -1084,6 +1098,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1084,6 +1098,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced) produced)
""" """
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp: with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \ assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8." "Need to run inside fp8_autocast region when weights are stored in FP8."
...@@ -1132,6 +1150,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1132,6 +1150,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
skip_fp8_weight_update,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
...@@ -1156,6 +1175,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1156,6 +1175,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad, self.ub_overlap_rs_dgrad,
self.ub_overlap_ag, self.ub_overlap_ag,
self.ub_name, self.ub_name,
self.dummy_tensor,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
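The noop_flag=skip_fp8_weight_update argument threaded through the fused cast/transpose calls above keeps the "skip this weight update" decision in device memory instead of in Python control flow, so the same recorded work can be replayed under a CUDA graph with either outcome. A rough sketch of the idea with a hypothetical plain-PyTorch helper (not the tex kernel):

import torch

def cast_unless_noop(src: torch.Tensor, dst: torch.Tensor, noop_flag: torch.Tensor) -> None:
    # Hypothetical stand-in for a fused kernel that consults a device-side flag:
    # the condition is data (torch.where), not a host-side `if`, so the op records
    # cleanly into a CUDA graph and honors whatever value the flag holds at replay.
    dst.copy_(torch.where(noop_flag.bool(), dst, src.to(dst.dtype)))

skip = torch.zeros(1)                                # 0 -> perform the cast, 1 -> keep old data
weight = torch.randn(8, 8)
weight_lp = torch.empty(8, 8, dtype=torch.float16)   # placeholder for an FP8 buffer
cast_unless_noop(weight, weight_lp, skip)
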
...@@ -13,7 +13,6 @@ from torch.nn import init ...@@ -13,7 +13,6 @@ from torch.nn import init
from .base import ( from .base import (
get_workspace, get_workspace,
_prepare_backward,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -94,6 +93,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -94,6 +93,7 @@ class _LayerNormMLP(torch.autograd.Function):
use_fc2_bias: bool, use_fc2_bias: bool,
eps: float, eps: float,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
...@@ -121,6 +121,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -121,6 +121,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_overlap_rs: bool, ub_overlap_rs: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
gemm_gelu_fusion: bool, gemm_gelu_fusion: bool,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -131,7 +132,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -131,7 +132,11 @@ class _LayerNormMLP(torch.autograd.Function):
assert_dim_for_fp8_exec(fc1_weight) assert_dim_for_fp8_exec(fc1_weight)
assert_dim_for_fp8_exec(fc2_weight) assert_dim_for_fp8_exec(fc2_weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
activation_func = _act_func(activation)[0] activation_func = _act_func(activation)[0]
...@@ -225,8 +230,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -225,8 +230,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.reset_fp8_meta_scale_inv() fc2_weight.reset_fp8_meta_scale_inv()
fc1_weight_fp8 = fc1_weight fc1_weight_fp8 = fc1_weight
fc2_weight_fp8 = fc2_weight fc2_weight_fp8 = fc2_weight
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
elif update_fp8_weights: elif update_fp8_weights:
# Need to cast weights to FP8 # Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor( fc1_weight_fp8 = Float8Tensor(
...@@ -250,6 +253,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -250,6 +253,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
cast_out=fc1_weight_fp8._data, cast_out=fc1_weight_fp8._data,
transpose_out=fc1_weight_t_fp8._data, transpose_out=fc1_weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
) )
tex.fp8_cast_transpose_fused( tex.fp8_cast_transpose_fused(
fc2_weight, fc2_weight,
...@@ -258,6 +262,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -258,6 +262,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
cast_out=fc2_weight_fp8._data, cast_out=fc2_weight_fp8._data,
transpose_out=fc2_weight_t_fp8._data, transpose_out=fc2_weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
) )
else: else:
tex.cast_to_fp8( tex.cast_to_fp8(
...@@ -510,6 +515,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -510,6 +515,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_t_fp8, fc2_weight_t_fp8,
fc1_bias, fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.activation = activation ctx.activation = activation
...@@ -538,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -538,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_overlap_ag = ub_overlap_ag ctx.ub_overlap_ag = ub_overlap_ag
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
...@@ -563,9 +570,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -563,9 +570,7 @@ class _LayerNormMLP(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward( with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
):
( (
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -582,6 +587,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -582,6 +587,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_t_fp8, fc2_weight_t_fp8,
fc1_bias, fc1_bias,
fwd_scale_inverses, fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
...@@ -592,11 +598,18 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -592,11 +598,18 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.main_grad = fc2_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy" if ctx.primary_weights_in_fp8:
if ctx.fp8 and fc1_weight_t_fp8 is None: fc1_weight_t_fp8 = fc1_weight.transpose_2d(
fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache) cache=ctx.is_first_microbatch is not None,
if ctx.fp8 and fc2_weight_t_fp8 is None: noop_flag=skip_fp8_weight_update,
fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache) )
fc2_weight_t_fp8 = fc2_weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
fc1_weight_t_fp8 = fc1_weight_t_fp8._data
fc2_weight_t_fp8 = fc2_weight_t_fp8._data
activation_func = _act_func(ctx.activation)[1] activation_func = _act_func(ctx.activation)[1]
...@@ -673,7 +686,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -673,7 +686,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC2 DGRAD; Unconditional # FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm( fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8._data, fc2_weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -826,7 +839,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -826,7 +839,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj = None ub_obj = None
# FC1 DGRAD: Unconditional # FC1 DGRAD: Unconditional
_ = tex.fp8_gemm( _ = tex.fp8_gemm(
fc1_weight_t_fp8._data, fc1_weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -1151,6 +1164,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1151,6 +1164,8 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
) )
...@@ -1389,7 +1404,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1389,7 +1404,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=2) self.init_fp8_metadata(num_gemms=2)
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta')) self.reset_parameters(defer_init=(device == 'meta'))
...@@ -1414,6 +1428,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1414,6 +1428,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
warnings.warn( warnings.warn(
...@@ -1473,7 +1491,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1473,7 +1491,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
""" """
Apply layer normalization to the input followed by a feedforward network (MLP Block). Apply layer normalization to the input followed by a feedforward network (MLP Block).
...@@ -1497,6 +1517,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1497,6 +1517,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced) produced)
""" """
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \ assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8." "Need to run inside fp8_autocast region when weights are stored in FP8."
...@@ -1535,6 +1559,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1535,6 +1559,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
skip_fp8_weight_update,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
...@@ -1562,6 +1587,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1562,6 +1587,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs, self.ub_overlap_rs,
self.ub_overlap_ag, self.ub_overlap_ag,
self.gemm_gelu_fusion, self.gemm_gelu_fusion,
self.dummy_tensor,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
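Dropping _prepare_backward from these backward passes removes host-side bookkeeping that cannot live inside a captured region. For reference, a minimal capture-and-replay loop with stock PyTorch CUDA graph APIs (an ordinary torch.nn.Linear, not a Transformer Engine module) looks roughly like this:

import torch

# Capture needs static buffers that are refreshed in place before every replay;
# nothing inside the captured region may allocate, synchronize, or branch on
# host-side values.
model = torch.nn.Linear(1024, 1024).cuda()
static_inp = torch.randn(32, 1024, device="cuda")

# Warm up on a side stream so lazy initialization is not recorded.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_inp)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_inp)

static_inp.copy_(torch.randn_like(static_inp))  # new data, same buffer
graph.replay()                                  # re-runs the captured forward
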
...@@ -11,7 +11,6 @@ import transformer_engine_extensions as tex ...@@ -11,7 +11,6 @@ import transformer_engine_extensions as tex
from .base import ( from .base import (
get_workspace, get_workspace,
_prepare_backward,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -65,6 +64,7 @@ class _Linear(torch.autograd.Function): ...@@ -65,6 +64,7 @@ class _Linear(torch.autograd.Function):
bias: torch.Tensor, bias: torch.Tensor,
use_bias: bool, use_bias: bool,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
...@@ -80,7 +80,8 @@ class _Linear(torch.autograd.Function): ...@@ -80,7 +80,8 @@ class _Linear(torch.autograd.Function):
primary_weights_in_fp8: bool, primary_weights_in_fp8: bool,
ub_overlap_rs: bool, ub_overlap_rs: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
ub_name: str ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> torch.Tensor: ) -> torch.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = weight.shape[-1] in_features = weight.shape[-1]
...@@ -90,7 +91,12 @@ class _Linear(torch.autograd.Function): ...@@ -90,7 +91,12 @@ class _Linear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight) assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
...@@ -140,7 +146,6 @@ class _Linear(torch.autograd.Function): ...@@ -140,7 +146,6 @@ class _Linear(torch.autograd.Function):
# Weight is already in FP8 # Weight is already in FP8
weight.reset_fp8_meta_scale_inv() weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights: elif update_fp8_weights:
# Need to cast weights to FP8 # Need to cast weights to FP8
weight_fp8 = Float8Tensor( weight_fp8 = Float8Tensor(
...@@ -158,6 +163,7 @@ class _Linear(torch.autograd.Function): ...@@ -158,6 +163,7 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
cast_out=weight_fp8._data, cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data, transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
) )
else: else:
cast_to_fp8( cast_to_fp8(
...@@ -296,6 +302,7 @@ class _Linear(torch.autograd.Function): ...@@ -296,6 +302,7 @@ class _Linear(torch.autograd.Function):
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None, weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8 ctx.fp8 = fp8
...@@ -313,6 +320,7 @@ class _Linear(torch.autograd.Function): ...@@ -313,6 +320,7 @@ class _Linear(torch.autograd.Function):
ctx.ub_name = ub_name ctx.ub_name = ub_name
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
...@@ -330,9 +338,7 @@ class _Linear(torch.autograd.Function): ...@@ -330,9 +338,7 @@ class _Linear(torch.autograd.Function):
def backward( def backward(
ctx, grad_output: torch.Tensor ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward( with torch.cuda.nvtx.range("_Linear_backward"):
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
):
( (
inputmat, inputmat,
inputmat_t, inputmat_t,
...@@ -340,6 +346,7 @@ class _Linear(torch.autograd.Function): ...@@ -340,6 +346,7 @@ class _Linear(torch.autograd.Function):
main_grad, main_grad,
weight_t_fp8, weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
...@@ -347,10 +354,14 @@ class _Linear(torch.autograd.Function): ...@@ -347,10 +354,14 @@ class _Linear(torch.autograd.Function):
weight.main_grad = main_grad weight.main_grad = main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None: if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose( weight_t_fp8 = weight.transpose_2d(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
) )
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
tp_world_size = get_distributed_world_size(ctx.tp_group) tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
if ctx.ub_overlap_ag: if ctx.ub_overlap_ag:
...@@ -361,6 +372,7 @@ class _Linear(torch.autograd.Function): ...@@ -361,6 +372,7 @@ class _Linear(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
else: else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
( (
grad_output, grad_output,
grad_output_c, grad_output_c,
...@@ -401,7 +413,7 @@ class _Linear(torch.autograd.Function): ...@@ -401,7 +413,7 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad: if ctx.requires_dgrad:
if ctx.fp8: if ctx.fp8:
dgrad, _ = fp8_gemm( dgrad, _ = fp8_gemm(
weight_t_fp8._data, weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -542,6 +554,8 @@ class _Linear(torch.autograd.Function): ...@@ -542,6 +554,8 @@ class _Linear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
) )
...@@ -772,7 +786,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -772,7 +786,6 @@ class Linear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata() self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta')) self.reset_parameters(defer_init=(device == 'meta'))
...@@ -785,6 +798,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -785,6 +798,10 @@ class Linear(TransformerEngineBaseModule):
else: else:
self.gemm_bias_unfused_add = False self.gemm_bias_unfused_add = False
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_parameters(self, defer_init=False): def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init) super().reset_parameters(defer_init=defer_init)
...@@ -858,6 +875,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -858,6 +875,10 @@ class Linear(TransformerEngineBaseModule):
produced) produced)
""" """
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp: with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \ assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8." "Need to run inside fp8_autocast region when weights are stored in FP8."
...@@ -903,6 +924,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -903,6 +924,7 @@ class Linear(TransformerEngineBaseModule):
bias_tensor, bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch, is_first_microbatch,
skip_fp8_weight_update,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
...@@ -919,6 +941,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -919,6 +941,7 @@ class Linear(TransformerEngineBaseModule):
self.ub_overlap_rs, self.ub_overlap_rs,
self.ub_overlap_ag, self.ub_overlap_ag,
self.ub_name, self.ub_name,
self.dummy_tensor,
) )
out = linear_fn(*args) out = linear_fn(*args)
......
...@@ -473,6 +473,15 @@ class TransformerLayer(torch.nn.Module): ...@@ -473,6 +473,15 @@ class TransformerLayer(torch.nn.Module):
if hasattr(child, "set_tensor_parallel_group"): if hasattr(child, "set_tensor_parallel_group"):
child.set_tensor_parallel_group(tp_group) child.set_tensor_parallel_group(tp_group)
def reset_fp8_meta_tensors(self) -> None:
"""Set TP group"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
continue
if hasattr(child, "reset_fp8_meta_tensors"):
child.reset_fp8_meta_tensors()
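A short usage sketch for the new helper (hypothetical layer sizes; it simply recurses into every sub-module that exposes reset_fp8_meta_tensors):

import transformer_engine.pytorch as te

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)
layer.reset_fp8_meta_tensors()   # clears FP8 amax/scale state of attention, MLP, etc.
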
def set_context_parallel_group( def set_context_parallel_group(
self, self,
cp_group: Union[dist_group_type, None], cp_group: Union[dist_group_type, None],
...@@ -665,7 +674,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -665,7 +674,8 @@ class TransformerLayer(torch.nn.Module):
# MLP. # MLP.
mlp_outputs = self.layernorm_mlp( mlp_outputs = self.layernorm_mlp(
hidden_states, is_first_microbatch=is_first_microbatch hidden_states,
is_first_microbatch=is_first_microbatch,
) )
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs mlp_output, mlp_bias, residual = mlp_outputs
......