Unverified commit 73f8d90f, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] cuda graph support (#575)



* FP8 cuda graphs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>

* Fix numerics
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* exclude torch compile from numerics tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More numerics fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm fusion from unfused path
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
parent 1b20f2d6
...@@ -157,7 +157,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -157,7 +157,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
} }
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
...@@ -238,13 +238,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -238,13 +238,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int ori_sms = _ub_comm->sms; int ori_sms = _ub_comm->sms;
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
for (int i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _stop_comm, 0));
}
if (A_scale_inverse.numel()) if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor]; A_scale_inverse = A_scale_inverse[A_fp8_tensor];
...@@ -350,11 +347,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -350,11 +347,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int ori_sms = _ub_comm->sms; int ori_sms = _ub_comm->sms;
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (int i = 0; i < _stream_compute.size(); i++) { for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
} }
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
if (A_scale_inverse.numel()) if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor]; A_scale_inverse = A_scale_inverse[A_fp8_tensor];
...@@ -469,13 +467,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -469,13 +467,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
} }
} }
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
_ub_comm->sms = ori_sms; _ub_comm->sms = ori_sms;
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main); at::cuda::setCurrentCUDAStream(stream_main);
...@@ -506,7 +504,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -506,7 +504,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
} }
} }
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(),
...@@ -805,14 +803,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -805,14 +803,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (B_scale_inverse.numel()) if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor]; B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
if (_aggregate2) { if (_aggregate2) {
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
const int num_steps = _tp_size / 2; const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.data_ptr()); char *input_b_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
...@@ -877,21 +876,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -877,21 +876,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
} }
} }
at::cuda::setCurrentCUDAStream(stream_main);
int last_compute_stream_id =
(num_steps + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
} else { } else {
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
for (int i = 0; i < _tp_size; i++) { for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. Buffer under send is the input for the current // Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
...@@ -936,16 +923,19 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -936,16 +923,19 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
} }
} }
at::cuda::setCurrentCUDAStream(stream_main); }
int last_compute_stream_id = (_tp_size + _stream_compute.size() - 1) % _stream_compute.size(); for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA( CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
} }
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return D; return D;
} // split_overlap_ag } // split_overlap_ag
......
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
#include <transformer_engine/softmax.h> #include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h> #include <transformer_engine/transpose.h>
#include <transformer_engine/cast_transpose_noop.h>
namespace transformer_engine { namespace transformer_engine {
......
...@@ -223,6 +223,17 @@ void fused_cast_transpose(at::Tensor input, ...@@ -223,6 +223,17 @@ void fused_cast_transpose(at::Tensor input,
); );
// Fused cast + transpose that additionally accepts a `noop` tensor.
// Results are written into the caller-provided `input_cast` and `input_transpose`
// tensors; nothing is returned.
// NOTE(review): the exact meaning of `noop` is defined by the underlying kernel
// and is not visible from this declaration — confirm before relying on it.
void fused_cast_transpose_noop(at::Tensor input,
at::Tensor noop,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
);
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output, std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
...@@ -263,6 +274,17 @@ at::Tensor fp8_transpose(at::Tensor input, ...@@ -263,6 +274,17 @@ at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype transformer_engine::DType otype
); );
// Transpose with FP8 I/O that writes into the caller-provided `output` tensor
// instead of returning a new one (contrast with `fp8_transpose`, which returns
// an at::Tensor).
void fp8_transpose_noalloc(at::Tensor input,
at::Tensor output,
transformer_engine::DType otype
);
// Same as `fp8_transpose_noalloc`, with an extra `noop` tensor argument.
// NOTE(review): the exact meaning of `noop` is defined by the underlying kernel
// and is not visible from this declaration — confirm before relying on it.
void fp8_transpose_noalloc_noop(at::Tensor input,
at::Tensor output,
at::Tensor noop,
transformer_engine::DType otype
);
/*************************************************************************************************** /***************************************************************************************************
* Activations * Activations
**************************************************************************************************/ **************************************************************************************************/
...@@ -559,16 +581,13 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads ...@@ -559,16 +581,13 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
* FP8 recipe * FP8 recipe
**************************************************************************************************/ **************************************************************************************************/
void fused_amax_and_scale_update(const at::Tensor &amax_history, void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
const at::Tensor &scale, std::vector<at::Tensor> amax_histories,
const at::Tensor &scale_inv, std::vector<at::Tensor> scales,
const at::Tensor &scale_inv_mask, std::vector<at::Tensor> scale_invs,
at::Tensor updated_amax_history, const std::string &amax_compute_algo,
at::Tensor updated_scale, transformer_engine::DType fp8_dtype,
at::Tensor updated_scale_inv, float margin);
const std::string& amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin);
/*************************************************************************************************** /***************************************************************************************************
* Rotary positional embedding * Rotary positional embedding
......
...@@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop,
"Fused Cast + Transpose with noop option");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD"); "Fused Cast + Transpose + BGRAD");
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad,
...@@ -67,6 +69,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -67,6 +69,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_attn_bwd", &fused_attn_bwd, m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O");
m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop,
"Transpose with FP8 I/O with noop option.");
m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("gelu", &gelu, "GeLU with FP8 output");
m.def("relu", &relu, "ReLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output");
m.def("geglu", &geglu, "GeGLU with FP8 output"); m.def("geglu", &geglu, "GeGLU with FP8 output");
...@@ -82,9 +87,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -82,9 +87,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
m.def("fused_amax_and_scale_update", m.def("fused_amax_and_scale_update_after_reduction",
&fused_amax_and_scale_update, &fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale"); "Update amax history and FP8 scale/scale_inv after reduction");
// fused apply rope // fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD"); m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD");
......
...@@ -11,24 +11,50 @@ ...@@ -11,24 +11,50 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
void fused_amax_and_scale_update(const at::Tensor &amax_history,
const at::Tensor &scale, void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
const at::Tensor &scale_inv, std::vector<at::Tensor> amax_histories,
const at::Tensor &scale_inv_mask, std::vector<at::Tensor> scales,
at::Tensor updated_amax_history, std::vector<at::Tensor> scale_invs,
at::Tensor updated_scale, const std::string &amax_compute_algo,
at::Tensor updated_scale_inv, transformer_engine::DType fp8_dtype,
const std::string& amax_compute_algo, float margin) {
transformer_engine::DType fp8_dtype, using namespace transformer_engine;
float margin) { size_t num_tensors = amax_histories.size();
nvte_delayed_scaling_recipe_amax_and_scale_update( std::vector<Tensor> t_amax_histories(num_tensors);
makeTransformerEngineTensor(amax_history).data(), std::vector<Tensor> t_scales(num_tensors);
makeTransformerEngineTensor(scale).data(), std::vector<Tensor> t_scale_invs(num_tensors);
makeTransformerEngineTensor(scale_inv).data(), std::vector<NVTETensor> te_amax_histories(num_tensors);
makeTransformerEngineTensor(scale_inv_mask).data(), std::vector<NVTETensor> te_scales(num_tensors);
makeTransformerEngineTensor(updated_amax_history).data(), std::vector<NVTETensor> te_scale_invs(num_tensors);
makeTransformerEngineTensor(updated_scale).data(), for (size_t i = 0; i < num_tensors; i++) {
makeTransformerEngineTensor(updated_scale_inv).data(), t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
auto amax_sizes = amax_histories[i].sizes().vec();
std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()};
t_amax_histories[i].data.shape = amax_shape;
t_amax_histories[i].data.dtype = DType::kFloat32;
t_scales[i].data.dptr = scales[i].data_ptr();
auto scale_sizes = scales[i].sizes().vec();
std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()};
t_scales[i].data.shape = scale_shape;
t_scales[i].data.dtype = DType::kFloat32;
t_scale_invs[i].data.dptr = scale_invs[i].data_ptr();
auto scale_inv_sizes = scale_invs[i].sizes().vec();
std::vector<size_t> scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()};
t_scale_invs[i].data.shape = scale_inv_shape;
t_scale_invs[i].data.dtype = DType::kFloat32;
te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
te_scale_invs[i] = reinterpret_cast<NVTETensor>(&t_scale_invs[i]);
}
nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
makeTransformerEngineTensor(amax_reduction_buffer).data(),
te_amax_histories,
te_scales,
te_scale_invs,
amax_compute_algo.c_str(), amax_compute_algo.c_str(),
static_cast<NVTEDType>(fp8_dtype), static_cast<NVTEDType>(fp8_dtype),
margin, margin,
......
...@@ -32,6 +32,35 @@ void fused_cast_transpose(at::Tensor input, ...@@ -32,6 +32,35 @@ void fused_cast_transpose(at::Tensor input,
} }
// Fused cast + transpose with a `noop` tensor.
//
// Treats `input` as a 2-D [rows, cols] tensor and launches
// nvte_cast_transpose_with_noop on the current CUDA stream. The cast result is
// written to `input_cast` (shape [rows, cols]) and the transposed result to
// `input_transpose` (shape [cols, rows]); both outputs use dtype `otype` and
// share the same `amax` / `scale` / `scale_inv` buffers.
//
// NOTE(review): `noop` is forwarded untouched to the kernel; presumably it lets
// the kernel skip its work when set — confirm against the kernel's contract.
void fused_cast_transpose_noop(at::Tensor input,
                               at::Tensor noop,
                               at::Tensor scale,
                               at::Tensor amax,
                               at::Tensor scale_inv,
                               at::Tensor input_cast,
                               at::Tensor input_transpose,
                               transformer_engine::DType otype
) {
  using namespace transformer_engine;

  // Only the first two dimensions of the input are consulted.
  const size_t rows = static_cast<size_t>(input.size(0));
  const size_t cols = static_cast<size_t>(input.size(1));

  auto src_te = makeTransformerEngineTensor(input);
  auto noop_te = makeTransformerEngineTensor(noop);

  // Both outputs are described with the same scaling metadata pointers.
  auto cast_te = makeTransformerEngineTensor(
      input_cast.data_ptr(), {rows, cols}, otype,
      amax.data_ptr(), scale.data_ptr(), scale_inv.data_ptr());
  auto transposed_te = makeTransformerEngineTensor(
      input_transpose.data_ptr(), {cols, rows}, otype,
      amax.data_ptr(), scale.data_ptr(), scale_inv.data_ptr());

  nvte_cast_transpose_with_noop(src_te.data(),
                                noop_te.data(),
                                cast_te.data(),
                                transposed_te.data(),
                                at::cuda::getCurrentCUDAStream());
}
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output, std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
...@@ -319,3 +348,39 @@ at::Tensor fp8_transpose(at::Tensor input, ...@@ -319,3 +348,39 @@ at::Tensor fp8_transpose(at::Tensor input,
return output; return output;
} }
// Transpose with FP8 I/O into a preallocated output.
//
// Treats `input` as a 2-D [rows, cols] buffer of dtype `otype` and launches
// nvte_transpose on the current CUDA stream, writing into `output`, which is
// described to the kernel as [cols, rows] with the same dtype. No tensor is
// allocated or returned here.
void fp8_transpose_noalloc(at::Tensor input,
                           at::Tensor output,
                           transformer_engine::DType otype
) {
  using namespace transformer_engine;

  const size_t rows = static_cast<size_t>(input.size(0));
  const size_t cols = static_cast<size_t>(input.size(1));

  // Wrap raw pointers; shapes and dtype are supplied explicitly.
  auto src_te = makeTransformerEngineTensor(input.data_ptr(), {rows, cols}, otype);
  auto dst_te = makeTransformerEngineTensor(output.data_ptr(), {cols, rows}, otype);

  nvte_transpose(src_te.data(), dst_te.data(), at::cuda::getCurrentCUDAStream());
}
// Transpose with FP8 I/O into a preallocated output, with a `noop` tensor.
//
// Treats `input` as a 2-D [rows, cols] buffer of dtype `otype` and launches
// nvte_transpose_with_noop on the current CUDA stream, writing into `output`,
// which is described to the kernel as [cols, rows] with the same dtype.
//
// NOTE(review): `noop` is forwarded untouched to the kernel; presumably it lets
// the kernel skip its work when set — confirm against the kernel's contract.
void fp8_transpose_noalloc_noop(at::Tensor input,
                                at::Tensor output,
                                at::Tensor noop,
                                transformer_engine::DType otype
) {
  using namespace transformer_engine;

  const size_t rows = static_cast<size_t>(input.size(0));
  const size_t cols = static_cast<size_t>(input.size(1));

  // Wrap raw pointers; shapes and dtype are supplied explicitly.
  auto src_te = makeTransformerEngineTensor(input.data_ptr(), {rows, cols}, otype);
  auto noop_te = makeTransformerEngineTensor(noop);
  auto dst_te = makeTransformerEngineTensor(output.data_ptr(), {cols, rows}, otype);

  nvte_transpose_with_noop(src_te.data(),
                           noop_te.data(),
                           dst_te.data(),
                           at::cuda::getCurrentCUDAStream());
}
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
"""Methods needed for distributed training (DP/TP).""" """Methods needed for distributed training (DP/TP)."""
import warnings import warnings
from contextlib import contextmanager, AbstractContextManager, ContextDecorator from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from typing import Any, Dict, Union, Optional, Callable, Tuple from typing import Any, Dict, List, Union, Optional, Callable, Tuple
import torch import torch
from torch.cuda import _lazy_call from torch.cuda import _lazy_call, _lazy_init
from torch.utils.checkpoint import detach_variable, noop_context_fn from torch.utils.checkpoint import detach_variable, noop_context_fn
from .utils import safely_set_viewless_tensor_data from .utils import safely_set_viewless_tensor_data
...@@ -31,15 +31,60 @@ _FP8_ACTIVATION_RECOMPUTE_ENABLED = False ...@@ -31,15 +31,60 @@ _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False _FP8_ACTIVATION_RECOMPUTE_PHASE = False
def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None: _ALL_ACTIVE_RNG_STATES = {}
"""Sets the random number generator state of the current GPU.
def get_all_rng_states() -> Dict[str, Any]:
    """Returns all generator states used by `CudaRNGStatesTracker`.

    The returned object is the module-level registry itself (not a copy), i.e.
    whatever was last installed through `set_all_rng_states`.
    """
    return _ALL_ACTIVE_RNG_STATES
def set_all_rng_states(states: Dict[str, Any]) -> None:
    """Updates all generator states used by `CudaRNGStatesTracker`.

    Arguments:
        states: mapping from state name to RNG state. The mapping is stored by
            reference (no copy is made), so later mutations by the caller are
            visible through `get_all_rng_states`.
    """
    global _ALL_ACTIVE_RNG_STATES
    _ALL_ACTIVE_RNG_STATES = states
def graph_safe_rng_available() -> bool:
    """Returns whether cuda graph safe RNG state manipulation is supported.

    Support is detected by probing the installed PyTorch for every attribute
    the graph-safe code paths rely on; all of them must be present.
    """
    required_attrs = (
        (torch.cuda.CUDAGraph, "register_generator_state"),
        (torch.Generator, "graphsafe_set_state"),
        (torch.Generator, "graphsafe_get_state"),
        (torch.Generator, "clone_state"),
    )
    return all(hasattr(owner, attr) for owner, attr in required_attrs)
def _get_cuda_rng_state(
    device: Union[int, str, torch.device] = "cuda",
    clone: bool = False,
    graph_safe: bool = True,
) -> Union[torch.Tensor, torch.Generator]:
    """Return the random number generator state of the specified GPU.

    Arguments:
        device: device whose default CUDA generator is queried. Strings and
            ints are converted to a `torch.device`; when the device carries no
            index, the current CUDA device is used.
        clone: only consulted on the graph-safe path; when True the result of
            `clone_state()` is returned instead of `graphsafe_get_state()`.
        graph_safe: request the graph-safe API. It is used only when
            `graph_safe_rng_available()` also reports support.

    Returns:
        On the graph-safe path, the object produced by the default generator's
        `clone_state()` / `graphsafe_get_state()` (callers use it as a
        generator, e.g. they call `.manual_seed(...)` on it). Otherwise the
        result of the default generator's `get_state()`.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)
    idx = device.index
    if idx is None:
        # No explicit index (e.g. plain "cuda"): fall back to the current device.
        idx = torch.cuda.current_device()
    default_generator = torch.cuda.default_generators[idx]
    if graph_safe_rng_available() and graph_safe:
        if clone:
            # Reference to the cloned generator state
            return default_generator.clone_state()
        # Reference to the current generator state
        return default_generator.graphsafe_get_state()
    return default_generator.get_state()
def _set_cuda_rng_state(
new_state: torch.Tensor,
device: Union[int, str] = -1,
graph_safe = True,
) -> None:
"""Sets the random number generator state of the current GPU."""
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
if device == -1: if device == -1:
device = torch.device("cuda") device = torch.device("cuda")
elif isinstance(device, str): elif isinstance(device, str):
...@@ -52,6 +97,9 @@ def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) - ...@@ -52,6 +97,9 @@ def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -
if idx is None: if idx is None:
idx = torch.cuda.current_device() idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx] default_generator = torch.cuda.default_generators[idx]
if graph_safe_rng_available() and graph_safe:
default_generator.graphsafe_set_state(new_state)
return
default_generator.set_state(new_state) default_generator.set_state(new_state)
_lazy_call(cb) _lazy_call(cb)
...@@ -206,7 +254,7 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -206,7 +254,7 @@ class _CheckpointFunction(torch.autograd.Function):
# Copy the rng states. # Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state() ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None: if get_rng_state_tracker is not None:
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
...@@ -271,13 +319,13 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -271,13 +319,13 @@ class _CheckpointFunction(torch.autograd.Function):
# Store the current states. # Store the current states.
bwd_cpu_rng_state = torch.get_rng_state() bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state() bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
if get_rng_state_tracker is not None: if get_rng_state_tracker is not None:
bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
# Set the states to what it used to be before the forward pass. # Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state) torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state) _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
if get_rng_state_tracker is not None: if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
...@@ -291,7 +339,7 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -291,7 +339,7 @@ class _CheckpointFunction(torch.autograd.Function):
# Set the states back to what it was at the start of this function. # Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state) torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state) _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
if get_rng_state_tracker is not None: if get_rng_state_tracker is not None:
get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker) get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
...@@ -317,6 +365,7 @@ class _CheckpointFunction(torch.autograd.Function): ...@@ -317,6 +365,7 @@ class _CheckpointFunction(torch.autograd.Function):
) )
return (None, None, None, None, None, None) + grads return (None, None, None, None, None, None) + grads
class _CheckpointFrame: class _CheckpointFrame:
""" """
Storage frame for forward RNG states and detached activations from the forward recompute. Storage frame for forward RNG states and detached activations from the forward recompute.
...@@ -338,7 +387,7 @@ class _CheckpointFrame: ...@@ -338,7 +387,7 @@ class _CheckpointFrame:
"""Cache fwd/bwd RNG states in the frame to restore later.""" """Cache fwd/bwd RNG states in the frame to restore later."""
rng_states = ( rng_states = (
torch.get_rng_state(), torch.get_rng_state(),
torch.cuda.get_rng_state(), _get_cuda_rng_state(graph_safe=False),
) )
if self.get_rng_state_tracker is not None: if self.get_rng_state_tracker is not None:
rng_states += (self.get_rng_state_tracker().get_states(), ) rng_states += (self.get_rng_state_tracker().get_states(), )
...@@ -356,7 +405,7 @@ class _CheckpointFrame: ...@@ -356,7 +405,7 @@ class _CheckpointFrame:
rng_states = self.bwd_rng_states rng_states = self.bwd_rng_states
torch.set_rng_state(rng_states[0]) torch.set_rng_state(rng_states[0])
_set_cuda_rng_state(rng_states[1]) _set_cuda_rng_state(rng_states[1], graph_safe=False)
if self.get_rng_state_tracker is not None: if self.get_rng_state_tracker is not None:
self.get_rng_state_tracker().set_states(rng_states[2]) self.get_rng_state_tracker().set_states(rng_states[2])
...@@ -604,6 +653,7 @@ def checkpoint( ...@@ -604,6 +653,7 @@ def checkpoint(
return out return out
class CudaRNGStatesTracker: class CudaRNGStatesTracker:
""" """
For model parallelism, multiple RNG states need to simultaneously exist in order For model parallelism, multiple RNG states need to simultaneously exist in order
...@@ -664,13 +714,23 @@ class CudaRNGStatesTracker: ...@@ -664,13 +714,23 @@ class CudaRNGStatesTracker:
# Check that state is not already defined. # Check that state is not already defined.
if name in self.states_: if name in self.states_:
raise Exception(f"cuda rng state {name} already exists") raise Exception(f"cuda rng state {name} already exists")
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state() if graph_safe_rng_available():
# Set the new state and store it. new_state = _get_cuda_rng_state(clone=True)
torch.cuda.manual_seed(seed) new_state.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state() self.states_[name] = new_state
# Reset rng state to what it was. # Update global states.
_set_cuda_rng_state(orig_rng_state) set_all_rng_states(self.states_)
else:
# Get the current rng state.
orig_rng_state = _get_cuda_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = _get_cuda_rng_state(clone=True)
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
# Update global states.
set_all_rng_states(self.states_)
@contextmanager @contextmanager
def fork(self, name: str = "model-parallel-rng"): def fork(self, name: str = "model-parallel-rng"):
...@@ -684,16 +744,17 @@ class CudaRNGStatesTracker: ...@@ -684,16 +744,17 @@ class CudaRNGStatesTracker:
# Check if we have added the state # Check if we have added the state
if name not in self.states_: if name not in self.states_:
raise Exception(f"cuda rng state {name} is not added") raise Exception(f"cuda rng state {name} is not added")
# Store current rng state. # Get the reference to current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state() orig_cuda_rng_state = _get_cuda_rng_state()
# Set rng state to the desired one # Set rng state to the desired one
_set_cuda_rng_state(self.states_[name]) _set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do. # Do the stuff we wanted to do.
try: try:
yield yield
finally: finally:
# Update the current rng state for later use. # this is redundant with graph-safe API
self.states_[name] = torch.cuda.get_rng_state() if not graph_safe_rng_available():
self.states_[name] = _get_cuda_rng_state()
# And set the state to the original state we started with. # And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state) _set_cuda_rng_state(orig_cuda_rng_state)
......
...@@ -16,6 +16,7 @@ from .fp8 import FP8GlobalStateManager ...@@ -16,6 +16,7 @@ from .fp8 import FP8GlobalStateManager
aten = torch.ops.aten aten = torch.ops.aten
c10d = torch.ops.c10d c10d = torch.ops.c10d
updated_fp8_params = {}
def _make_fp8_attr_property_funcs(name: str) -> Any: def _make_fp8_attr_property_funcs(name: str) -> Any:
...@@ -67,6 +68,31 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -67,6 +68,31 @@ class _FromFloat8Func(torch.autograd.Function):
return grad, None return grad, None
def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
    """Trigger the forward amax reduction once all tracked FP8 params were updated.

    The parameter is identified by ``id(param._data)``. Each call records that
    id in the module-level ``updated_fp8_params`` dict under the parameter's
    autocast key. When the recorded ids equal the set stored in
    ``FP8GlobalStateManager.autocast_to_fp8_params`` for that key,
    ``reduce_and_update_fp8_tensors(forward=True, fp8_weights=True)`` is
    called and the bookkeeping entry for the key is dropped.

    Parameters
    ----------
    param: Float8Tensor
        FP8 tensor that has just been written in place.
    """
    data_id = id(param._data)

    # Nothing to do for tensors that are not mapped to an autocast key.
    if data_id not in FP8GlobalStateManager.fp8_param_to_autocast:
        return
    key = FP8GlobalStateManager.fp8_param_to_autocast[data_id]

    # Nothing to do when no parameter set is registered for this key.
    if key not in FP8GlobalStateManager.autocast_to_fp8_params:
        return

    # Record this parameter as updated for its autocast key.
    seen_ids = updated_fp8_params.setdefault(key, set())
    seen_ids.add(data_id)

    expected_ids = FP8GlobalStateManager.autocast_to_fp8_params[key]
    # All FP8 trainable parameters have been updated.
    if seen_ids == expected_ids:
        FP8GlobalStateManager.reduce_and_update_fp8_tensors(
            forward=True, fp8_weights=True)
        del updated_fp8_params[key]
class _ToFloat8Func(torch.autograd.Function): class _ToFloat8Func(torch.autograd.Function):
"""Cast to FP8 from other dtype""" """Cast to FP8 from other dtype"""
@staticmethod @staticmethod
...@@ -167,6 +193,7 @@ class _ToFloat8Func(torch.autograd.Function): ...@@ -167,6 +193,7 @@ class _ToFloat8Func(torch.autograd.Function):
# Assume that we want gradients in full precision # Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None return grad, None, None, None, None, None, None, None
class _IdentityFunc(torch.autograd.Function): class _IdentityFunc(torch.autograd.Function):
"""Identity function """Identity function
...@@ -307,8 +334,9 @@ class Float8Tensor(torch.Tensor): ...@@ -307,8 +334,9 @@ class Float8Tensor(torch.Tensor):
), f"Unsupported fp8_dtype {fp8_dtype}." ), f"Unsupported fp8_dtype {fp8_dtype}."
self._fp8_dtype: tex.DType = fp8_dtype self._fp8_dtype: tex.DType = fp8_dtype
# Cached transpose # Transposed version of `_data`.
self._transpose: Optional[Float8Tensor] = None self._transpose: Optional[Float8Tensor] = None
self._transpose_invalid: bool = True
# FP8 scale-inverse # FP8 scale-inverse
self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
...@@ -435,80 +463,51 @@ class Float8Tensor(torch.Tensor): ...@@ -435,80 +463,51 @@ class Float8Tensor(torch.Tensor):
return _IdentityFunc.apply(self) return _IdentityFunc.apply(self)
return super().expand_as(other) return super().expand_as(other)
def transpose( def transpose_2d(
self, self,
dim0: int = 0,
dim1: int = 1,
*, *,
update_cache: str | bool = "reuse_only", cache: bool = False,
noop_flag: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Swap tensor dimensions 2D transpose with caching support.
For basic 2D matrix transposes, an optimized transpose kernel
is applied and a Float8Tensor is returned.
Parameters Parameters
---------- ----------
dim0: int, default = 0 cache: bool, default = `False`
The first dimension to be transposed Whether or not to cache the transpose.
dim1: int, default = 1 noop_flag: Optional[torch.Tensor], default = `None`
The second dimension to be transposed Only used if argument `cache` is `True`, ignored otherwise.
update_cache: str or bool, default = "reuse_only" A single element fp32 tensor with a value of 1.0 or 0.0
Memoization behavior. Options are which is treated as a boolean. `1.0` forces recompute
"reuse_only"/`False` (reuse cached value if and `0.0` executes a noop using the same kernel.
available, otherwise calculate transpose without
caching), "force"/`True` (calculate transpose
and cache), "lazy" (reuse cached value if
available, otherwise calculate transpose and
cache if possible). Caching is only supported
for basic 2D transposes and the cache is reset
after any in-place operations.
""" """
assert self.dim() == 2, f"{self.dim()}-D transpose not supported."
# Check caching mode # Case: no caching.
if not isinstance(update_cache, str): if not cache:
update_cache = "force" if update_cache else "reuse_only" return tex.fp8_transpose(self._data, self._fp8_dtype)
if update_cache not in ("force", "reuse_only", "lazy"):
raise ValueError(
"Supported values for update_cache are "
'"force" (True), "reuse_only" (False), "lazy" '
f"(got {update_cache})"
)
# Handle non-2D transposes # Case: reuse cache without calling a kernel.
if -self.dim() <= dim0 < 0: if not self._transpose_invalid and noop_flag is None:
dim0 += self.dim() assert self._transpose is not None, "Tranpose cache is empty."
if -self.dim() <= dim1 < 0: return self._transpose
dim1 += self.dim()
if self.dim() != 2 or dim0 == dim1:
if update_cache == "force":
raise ValueError(
"Transpose caching is only supported for basic 2D transposes "
f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
)
return super().transpose(dim0, dim1)
# Clear cache if needed
if update_cache == "force":
self._transpose = None
# Compute transpose if needed
out = self._transpose
if out is None:
out = Float8Tensor.make_like(
self,
data=tex.fp8_transpose(
self._data.contiguous(),
self._fp8_dtype,
),
)
# Update cache if needed # Allocate transpose if needed.
if update_cache in ("force", "lazy"): data_2d = self._data.reshape(-1, self._data.shape[-1])
self._transpose = out if self._transpose is None:
return out shape = (data_2d.shape[1], data_2d.shape[0])
self._transpose = torch.empty(shape, dtype=torch.uint8, device=self._data.device)
# Case: recompute transpose and store cache.
if noop_flag is None:
tex.fp8_transpose_noalloc(data_2d, self._transpose, self._fp8_dtype)
else:
# Case: cuda graph capture.
tex.fp8_transpose_noalloc_noop(data_2d, self._transpose, noop_flag, self._fp8_dtype)
self._transpose_invalid = False
return self._transpose
@torch.no_grad() @torch.no_grad()
def reset_fp8_meta_scale_inv(self) -> None: def reset_fp8_meta_scale_inv(self) -> None:
...@@ -519,13 +518,11 @@ class Float8Tensor(torch.Tensor): ...@@ -519,13 +518,11 @@ class Float8Tensor(torch.Tensor):
the tensor. the tensor.
""" """
if self._fp8_meta is None: assert self._fp8_meta is not None, "FP8 meta tensors not found."
return
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=self._fp8_meta_forward, forward=self._fp8_meta_forward,
) )
scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0])
scale_inv.view(1).copy_(self._scale_inv.view(1))
def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: def to_dtype(self, dtype: torch.dtype) -> Float8Tensor:
"""Create `Float8Tensor` with given nominal dtype """Create `Float8Tensor` with given nominal dtype
...@@ -541,12 +538,11 @@ class Float8Tensor(torch.Tensor): ...@@ -541,12 +538,11 @@ class Float8Tensor(torch.Tensor):
) )
def _reset_caches(self) -> None: def _reset_caches(self) -> None:
"""Reset cached values """
Set transpose cache as invalid.
Should be called after any in-place operation. Should be called after any in-place operation.
""" """
self._transpose = None self._transpose_invalid = True
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None): def __torch_dispatch__(cls, func, types, args, kwargs=None):
...@@ -574,7 +570,7 @@ class Float8Tensor(torch.Tensor): ...@@ -574,7 +570,7 @@ class Float8Tensor(torch.Tensor):
# Directly copy FP8 data if possible # Directly copy FP8 data if possible
if dst._fp8_dtype == src._fp8_dtype: if dst._fp8_dtype == src._fp8_dtype:
dst._data.copy_(src._data) dst._data.copy_(src._data)
dst._scale_inv = src._scale_inv.clone() dst._scale_inv.copy_(src._scale_inv.detach())
if dst._fp8_meta is not None: if dst._fp8_meta is not None:
if src._fp8_meta is None: if src._fp8_meta is None:
src_min, src_max = src.from_float8().aminmax() src_min, src_max = src.from_float8().aminmax()
...@@ -600,7 +596,6 @@ class Float8Tensor(torch.Tensor): ...@@ -600,7 +596,6 @@ class Float8Tensor(torch.Tensor):
dst.copy_(src.from_float8()) dst.copy_(src.from_float8())
elif dst_is_fp8 and not src_is_fp8: elif dst_is_fp8 and not src_is_fp8:
# Make sure input is in expected format # Make sure input is in expected format
src = src.expand(dst.size()) src = src.expand(dst.size())
src = src.to( src = src.to(
...@@ -619,7 +614,7 @@ class Float8Tensor(torch.Tensor): ...@@ -619,7 +614,7 @@ class Float8Tensor(torch.Tensor):
fp8_meta_index = dst._fp8_meta_index fp8_meta_index = dst._fp8_meta_index
scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
dst._scale_inv = scale.detach().view(1).reciprocal() dst._scale_inv.copy_(scale.detach().reciprocal())
# Cast to FP8 # Cast to FP8
if not dst._data.is_contiguous(): if not dst._data.is_contiguous():
...@@ -633,6 +628,9 @@ class Float8Tensor(torch.Tensor): ...@@ -633,6 +628,9 @@ class Float8Tensor(torch.Tensor):
dst._fp8_dtype, dst._fp8_dtype,
) )
# This branch is where the FP8 parameters are updated in-place during optimization.
# Handle forward amax reduction.
post_optimizer_step_fwd_amax_reduction(dst)
else: else:
# Invalid case # Invalid case
...@@ -641,6 +639,7 @@ class Float8Tensor(torch.Tensor): ...@@ -641,6 +639,7 @@ class Float8Tensor(torch.Tensor):
# Nothing to return for in-place ops # Nothing to return for in-place ops
if dst_is_fp8: if dst_is_fp8:
dst._reset_caches() dst._reset_caches()
return None return None
# Slice op # Slice op
...@@ -764,6 +763,7 @@ class Float8Tensor(torch.Tensor): ...@@ -764,6 +763,7 @@ class Float8Tensor(torch.Tensor):
_fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index"))
_fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype"))
_transpose = property(**_make_fp8_attr_property_funcs("transpose")) _transpose = property(**_make_fp8_attr_property_funcs("transpose"))
_transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid"))
_scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv"))
# Do not force the Float8Tensor type on the returned tensor # Do not force the Float8Tensor type on the returned tensor
......
This diff is collapsed.
This diff is collapsed.
...@@ -8,8 +8,7 @@ import os ...@@ -8,8 +8,7 @@ import os
import pickle import pickle
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generator, Union, Optional, Tuple, Dict, Any, List from typing import Generator, Union, Optional, Tuple, List
from functools import partial
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
...@@ -22,13 +21,11 @@ from ..fp8 import ( ...@@ -22,13 +21,11 @@ from ..fp8 import (
get_default_fp8_recipe, get_default_fp8_recipe,
get_fp8_te_dtype, get_fp8_te_dtype,
FP8GlobalStateManager, FP8GlobalStateManager,
amax_and_scale_update,
) )
from ..distributed import ( from ..distributed import (
gather_along_first_dim, gather_along_first_dim,
is_fp8_activation_recompute_enabled, is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase, in_fp8_activation_recompute_phase,
get_distributed_world_size,
) )
from ..cpp_extensions import ( from ..cpp_extensions import (
fp8_cast_transpose_fused, fp8_cast_transpose_fused,
...@@ -44,7 +41,6 @@ _2X_ACC_WGRAD = True ...@@ -44,7 +41,6 @@ _2X_ACC_WGRAD = True
_cublas_workspace = None _cublas_workspace = None
_ub_communicators = None _ub_communicators = None
_NUM_MAX_UB_STREAMS = 3 _NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None
layers_atomic_ring_exchange = [] layers_atomic_ring_exchange = []
...@@ -64,49 +60,6 @@ def get_workspace() -> torch.Tensor: ...@@ -64,49 +60,6 @@ def get_workspace() -> torch.Tensor:
) )
return _cublas_workspace return _cublas_workspace
@contextmanager
def _prepare_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
name: str = ""
) -> Generator[None, None, None]:
"""Checks and prep for BWD."""
if fp8:
global _amax_reduce_handle_bwd
if _amax_reduce_handle_bwd is not None:
_amax_reduce_handle_bwd.wait()
_amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction
if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1:
# From previous iteration
FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False)
else:
amax_and_scale_update(fp8_meta, False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if (fp8 and fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(fp8_meta["fp8_group"]) > 1):
if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction(
fp8_meta,
tp_group,
tp_size,
forward=False
)
FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False)
def initialize_ub( def initialize_ub(
shape: list, shape: list,
...@@ -300,31 +253,54 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -300,31 +253,54 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_size = 1 self.tp_size = 1
self.sequence_parallel = False self.sequence_parallel = False
self.fp8_weight_shapes = [] self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
)
self.param_init_meta = {} self.param_init_meta = {}
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
    """Increase or decrease size of amax history based on given `length`.

    Parameters
    ----------
    length: int
        Desired number of rows (first dimension) of `amax_history`.
    fwd: Optional[bool], default = `None`
        `True` adjusts only "scaling_fwd", `False` only "scaling_bwd",
        and `None` adjusts both.

    .. warning::
        This changes the underlying amax memory location.
    """
    if fwd is None:
        fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
    else:
        fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

    for meta_key in fp8_meta_tensor_keys:
        curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
        # History already has the requested length: leave the tensor (and
        # therefore the global buffer entries) untouched.
        if length == curr_len:
            continue
        if length < curr_len:
            # Shrink: keep the first `length` rows in a fresh tensor.
            self.fp8_meta[meta_key].amax_history = (
                self.fp8_meta[meta_key].amax_history[: length].clone())
        elif length > curr_len:
            # Grow: append `extra_rows` padded rows along the first dimension.
            extra_rows = length - curr_len
            self.fp8_meta[meta_key].amax_history = F.pad(
                self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
            )

        # Update the global buffers with new amax and history pointers.
        # The attribute was rebound to a new tensor above, so entries that
        # referenced the old tensor are replaced here.
        # NOTE(review): both the fwd and bwd buffer slots are assigned the
        # history of the current `meta_key` -- confirm this is intended.
        if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
            fwd_pos, fwd_key, bwd_pos, bwd_key = (
                self.fp8_meta[FP8GlobalStateManager.get_buffer_info()])
            for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
                if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                    assert (
                        buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                    ), "TE internal error during amax history change."
                    FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
                        self.fp8_meta[meta_key].amax_history[0])
                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
                        self.fp8_meta[meta_key].amax_history)
def set_meta_tensor(self, fwd: bool) -> None: def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
if self.fp8_meta_tensors_initialized: if self.fp8_meta_tensors_initialized:
# Handle changed amax history size. # Handle changed amax history size.
curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0] self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd)
need_len = self.fp8_meta["recipe"].amax_history_len
if need_len < curr_len:
self.fp8_meta[fp8_meta_tensor_key].amax_history = (
self.fp8_meta[fp8_meta_tensor_key]
.amax_history[: self.fp8_meta["recipe"].amax_history_len].clone()
)
elif need_len > curr_len:
extra_rows = need_len - curr_len
self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad(
self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows)
)
return return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
...@@ -347,25 +323,45 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -347,25 +323,45 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
device="cuda", device="cuda",
) )
# Needed for calculation of scale inverses to
# preserve scale_inv when caching FP8 weights
if fwd:
# [True, False, True]: -> [input, weight, output]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, False, True] * self.fp8_meta["num_gemms"]
).cuda()
else:
# [True, True]: -> [grad_output, grad_input]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, True] * self.fp8_meta["num_gemms"]
).cuda()
def init_fp8_meta_tensors(self) -> None: def init_fp8_meta_tensors(self) -> None:
"""Init scales and amaxes.""" """Init scales and amaxes."""
self.set_meta_tensor(True) self.set_meta_tensor(True)
self.set_meta_tensor(False) self.set_meta_tensor(False)
self.fp8_meta_tensors_initialized = True self.fp8_meta_tensors_initialized = True
def get_fp8_meta_tensors(self) -> Optional[dict]:
    """Get scales and amaxes.

    Returns
    -------
    `None` if either "scaling_fwd" or "scaling_bwd" is missing from
    `self.fp8_meta`. Otherwise a dict with those two keys, each mapping to
    a list of cloned tensors in the order `[scale, scale_inv, amax_history]`.
    The clones are independent of the live meta tensors, so the result can
    later be passed to `reset_fp8_meta_tensors` to restore this state.
    """
    fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
    if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
        return None

    fp8_meta_tensors = {fwd_key: [], bwd_key: []}
    # Cloning is pure bookkeeping; keep it out of autograd.
    with torch.no_grad():
        for key in (fwd_key, bwd_key):
            meta = self.fp8_meta[key]
            fp8_meta_tensors[key].extend(
                (meta.scale.clone(), meta.scale_inv.clone(), meta.amax_history.clone())
            )
    return fp8_meta_tensors
def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
    """Reset scales and amaxes in place.

    Parameters
    ----------
    fp8_meta_tensors: default = `None`
        When `None`, `scale` and `scale_inv` are overwritten with ones and
        `amax_history` with zeros. Otherwise a mapping from "scaling_fwd" /
        "scaling_bwd" to a sequence `[scale, scale_inv, amax_history]`
        whose values are copied into the existing tensors.
    """
    with torch.no_grad():
        for meta_key in ("scaling_fwd", "scaling_bwd"):
            # Keys that were never initialized are silently skipped.
            if meta_key not in self.fp8_meta:
                continue

            if fp8_meta_tensors is not None:
                assert meta_key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                saved = fp8_meta_tensors[meta_key]
                self.fp8_meta[meta_key].scale.copy_(saved[0])
                self.fp8_meta[meta_key].scale_inv.copy_(saved[1])
                self.fp8_meta[meta_key].amax_history.copy_(saved[2])
            else:
                meta = self.fp8_meta[meta_key]
                meta.scale.copy_(torch.ones_like(meta.scale))
                meta.scale_inv.copy_(torch.ones_like(meta.scale_inv))
                meta.amax_history.copy_(torch.zeros_like(meta.amax_history))
def get_extra_state(self) -> torch.Tensor: def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing.""" """Save before checkpointing."""
state = None state = None
...@@ -380,13 +376,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -380,13 +376,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint()
state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint()
# Store other picklable values. # Store other picklable values.
extra = {} extra = {}
for k, v in self.fp8_meta.items(): for k, v in self.fp8_meta.items():
if isinstance(v, (bool, int, float, str, list)): if isinstance(v, (bool, int, float, str, tuple, list)):
extra[k] = v extra[k] = v
state["extra_fp8_variables"] = extra state["extra_fp8_variables"] = extra
...@@ -414,11 +408,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -414,11 +408,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None: if state is None:
return return
# Restore global FP8 amax buffer.
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
# Restore global FP8 state.
FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
# Load extra items. # Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
...@@ -527,6 +516,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -527,6 +516,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.tp_group = tp_group self.tp_group = tp_group
self.tp_group_initialized = True self.tp_group_initialized = True
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
    """Collect the trainable FP8 weights of this module.

    Returns the parameters from `self.parameters()` that are `Float8Tensor`
    instances with `requires_grad` set, or `None` when there are none.
    """
    trainable_fp8 = [
        weight
        for weight in self.parameters()
        if isinstance(weight, Float8Tensor) and weight.requires_grad
    ]
    return trainable_fp8 if trainable_fp8 else None
# This routine is shared across FP8 and FP8_calibration paths so should not actually # This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution. # assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None: def init_fp8_metadata(self, num_gemms: int = 1) -> None:
...@@ -576,7 +575,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -576,7 +575,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
to setup the forward aggregated amax reduction for every module to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one. just in case. The autocast exit will pick up the most recent one.
""" """
# Activation recomputation is used and this is the second forward phase. # Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
...@@ -594,49 +592,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -594,49 +592,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if is_first_microbatch is not None and not self.primary_weights_in_fp8: if is_first_microbatch is not None and not self.primary_weights_in_fp8:
self.set_fp8_weights() self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if self.fp8 and self.sequence_parallel: if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \ assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \ "Amax reduction across tensor parallel group is " \
"necessary when using sequence parallelism with FP8." "necessary when using sequence parallelism with FP8."
# Previous iteration was grad_enabled if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
if self.fp8_meta.get("update_amax_and_scale_fwd", False): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
if (self.fp8_meta["recipe"].reduce_amax self.fp8_meta, fp8_weights=self._get_fp8_params())
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
if (self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.new_fp8_context_id())
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = (
FP8GlobalStateManager.get_fp8_context_id())
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
# Activation recomputation is used and this is the first forward phase. # Activation recomputation is used and this is the first forward phase.
if ( if (
...@@ -653,18 +616,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -653,18 +616,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
return return
if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax
and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1):
FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
FP8GlobalStateManager.global_amax_reduction,
self.fp8_meta,
self.tp_group,
self.tp_size,
forward=True
)
FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func)
def set_nccl_overlap_warning_if_tp(self) -> None: def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled """When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the before the GEMM for there to be a guaranteed overlap. From the
......
...@@ -14,7 +14,6 @@ from .. import cpp_extensions as tex ...@@ -14,7 +14,6 @@ from .. import cpp_extensions as tex
from .base import ( from .base import (
get_workspace, get_workspace,
_prepare_backward,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -65,6 +64,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -65,6 +64,7 @@ class _LayerNormLinear(torch.autograd.Function):
use_bias: bool, use_bias: bool,
eps: float, eps: float,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
...@@ -89,6 +89,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -89,6 +89,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_rs_dgrad: bool, ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
ub_name: str, ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -98,7 +99,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -98,7 +99,11 @@ class _LayerNormLinear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight) assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
# Cast for native AMP # Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype) inputmat = cast_if_needed(inputmat, activation_dtype)
...@@ -196,7 +201,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -196,7 +201,6 @@ class _LayerNormLinear(torch.autograd.Function):
# Weight is already in FP8 # Weight is already in FP8
weight.reset_fp8_meta_scale_inv() weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights: elif update_fp8_weights:
# Need to cast weights to FP8 # Need to cast weights to FP8
weight_fp8 = Float8Tensor( weight_fp8 = Float8Tensor(
...@@ -214,6 +218,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -214,6 +218,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
cast_out=weight_fp8._data, cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data, transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
) )
else: else:
tex.cast_to_fp8( tex.cast_to_fp8(
...@@ -295,6 +300,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -295,6 +300,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_t_fp8, weight_t_fp8,
ln_out if weight.requires_grad else None, ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
...@@ -321,6 +327,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -321,6 +327,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ub_name = ub_name ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and sequence_parallel: if parallel_mode == "row" and sequence_parallel:
...@@ -344,9 +351,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -344,9 +351,7 @@ class _LayerNormLinear(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward( with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
):
( (
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -357,6 +362,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -357,6 +362,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_t_fp8, weight_t_fp8,
ln_out, ln_out,
fwd_scale_inverses, fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
...@@ -364,10 +370,13 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -364,10 +370,13 @@ class _LayerNormLinear(torch.autograd.Function):
weight.main_grad = main_grad weight.main_grad = main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None: if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose( weight_t_fp8 = weight.transpose_2d(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
) )
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
if ctx.ub_overlap_rs_dgrad: if ctx.ub_overlap_rs_dgrad:
ctx.ub_bulk_dgrad = False ctx.ub_bulk_dgrad = False
...@@ -472,7 +481,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -472,7 +481,7 @@ class _LayerNormLinear(torch.autograd.Function):
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm( _ = tex.fp8_gemm(
weight_t_fp8._data, weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -686,6 +695,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -686,6 +695,8 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
) )
...@@ -970,7 +981,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -970,7 +981,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata() self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta')) self.reset_parameters(defer_init=(device == 'meta'))
...@@ -990,6 +1000,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -990,6 +1000,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
warnings.warn( warnings.warn(
...@@ -1084,6 +1098,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1084,6 +1098,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced) produced)
""" """
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp: with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \ assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8." "Need to run inside fp8_autocast region when weights are stored in FP8."
...@@ -1132,6 +1150,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1132,6 +1150,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
skip_fp8_weight_update,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
...@@ -1156,6 +1175,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1156,6 +1175,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_overlap_rs_dgrad, self.ub_overlap_rs_dgrad,
self.ub_overlap_ag, self.ub_overlap_ag,
self.ub_name, self.ub_name,
self.dummy_tensor,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
...@@ -13,7 +13,6 @@ from torch.nn import init ...@@ -13,7 +13,6 @@ from torch.nn import init
from .base import ( from .base import (
get_workspace, get_workspace,
_prepare_backward,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -94,6 +93,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -94,6 +93,7 @@ class _LayerNormMLP(torch.autograd.Function):
use_fc2_bias: bool, use_fc2_bias: bool,
eps: float, eps: float,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
...@@ -121,6 +121,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -121,6 +121,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_overlap_rs: bool, ub_overlap_rs: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
gemm_gelu_fusion: bool, gemm_gelu_fusion: bool,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -131,7 +132,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -131,7 +132,11 @@ class _LayerNormMLP(torch.autograd.Function):
assert_dim_for_fp8_exec(fc1_weight) assert_dim_for_fp8_exec(fc1_weight)
assert_dim_for_fp8_exec(fc2_weight) assert_dim_for_fp8_exec(fc2_weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
activation_func = _act_func(activation)[0] activation_func = _act_func(activation)[0]
...@@ -225,8 +230,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -225,8 +230,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.reset_fp8_meta_scale_inv() fc2_weight.reset_fp8_meta_scale_inv()
fc1_weight_fp8 = fc1_weight fc1_weight_fp8 = fc1_weight
fc2_weight_fp8 = fc2_weight fc2_weight_fp8 = fc2_weight
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
elif update_fp8_weights: elif update_fp8_weights:
# Need to cast weights to FP8 # Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor( fc1_weight_fp8 = Float8Tensor(
...@@ -250,6 +253,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -250,6 +253,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
cast_out=fc1_weight_fp8._data, cast_out=fc1_weight_fp8._data,
transpose_out=fc1_weight_t_fp8._data, transpose_out=fc1_weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
) )
tex.fp8_cast_transpose_fused( tex.fp8_cast_transpose_fused(
fc2_weight, fc2_weight,
...@@ -258,6 +262,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -258,6 +262,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
cast_out=fc2_weight_fp8._data, cast_out=fc2_weight_fp8._data,
transpose_out=fc2_weight_t_fp8._data, transpose_out=fc2_weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
) )
else: else:
tex.cast_to_fp8( tex.cast_to_fp8(
...@@ -510,6 +515,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -510,6 +515,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_t_fp8, fc2_weight_t_fp8,
fc1_bias, fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.activation = activation ctx.activation = activation
...@@ -538,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -538,6 +544,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_overlap_ag = ub_overlap_ag ctx.ub_overlap_ag = ub_overlap_ag
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
...@@ -563,9 +570,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -563,9 +570,7 @@ class _LayerNormMLP(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward( with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
):
( (
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -582,6 +587,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -582,6 +587,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight_t_fp8, fc2_weight_t_fp8,
fc1_bias, fc1_bias,
fwd_scale_inverses, fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
...@@ -592,11 +598,18 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -592,11 +598,18 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_weight.main_grad = fc2_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy" if ctx.primary_weights_in_fp8:
if ctx.fp8 and fc1_weight_t_fp8 is None: fc1_weight_t_fp8 = fc1_weight.transpose_2d(
fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache) cache=ctx.is_first_microbatch is not None,
if ctx.fp8 and fc2_weight_t_fp8 is None: noop_flag=skip_fp8_weight_update,
fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache) )
fc2_weight_t_fp8 = fc2_weight.transpose_2d(
cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
)
elif ctx.fp8:
fc1_weight_t_fp8 = fc1_weight_t_fp8._data
fc2_weight_t_fp8 = fc2_weight_t_fp8._data
activation_func = _act_func(ctx.activation)[1] activation_func = _act_func(ctx.activation)[1]
...@@ -673,7 +686,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -673,7 +686,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC2 DGRAD; Unconditional # FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm( fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8._data, fc2_weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT, tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -826,7 +839,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -826,7 +839,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj = None ub_obj = None
# FC1 DGRAD: Unconditional # FC1 DGRAD: Unconditional
_ = tex.fp8_gemm( _ = tex.fp8_gemm(
fc1_weight_t_fp8._data, fc1_weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -1151,6 +1164,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1151,6 +1164,8 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
) )
...@@ -1389,7 +1404,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1389,7 +1404,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=2) self.init_fp8_metadata(num_gemms=2)
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta')) self.reset_parameters(defer_init=(device == 'meta'))
...@@ -1414,6 +1428,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1414,6 +1428,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
warnings.warn( warnings.warn(
...@@ -1473,7 +1491,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1473,7 +1491,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None self,
inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
""" """
Apply layer normalization to the input followed by a feedforward network (MLP Block). Apply layer normalization to the input followed by a feedforward network (MLP Block).
...@@ -1497,6 +1517,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1497,6 +1517,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced) produced)
""" """
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \ assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8." "Need to run inside fp8_autocast region when weights are stored in FP8."
...@@ -1535,6 +1559,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1535,6 +1559,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
skip_fp8_weight_update,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
...@@ -1562,6 +1587,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1562,6 +1587,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_overlap_rs, self.ub_overlap_rs,
self.ub_overlap_ag, self.ub_overlap_ag,
self.gemm_gelu_fusion, self.gemm_gelu_fusion,
self.dummy_tensor,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
...@@ -11,7 +11,6 @@ import transformer_engine_extensions as tex ...@@ -11,7 +11,6 @@ import transformer_engine_extensions as tex
from .base import ( from .base import (
get_workspace, get_workspace,
_prepare_backward,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
...@@ -65,6 +64,7 @@ class _Linear(torch.autograd.Function): ...@@ -65,6 +64,7 @@ class _Linear(torch.autograd.Function):
bias: torch.Tensor, bias: torch.Tensor,
use_bias: bool, use_bias: bool,
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
skip_fp8_weight_update: Union[torch.Tensor, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
fp8_meta: Dict[str, Any], fp8_meta: Dict[str, Any],
...@@ -80,7 +80,8 @@ class _Linear(torch.autograd.Function): ...@@ -80,7 +80,8 @@ class _Linear(torch.autograd.Function):
primary_weights_in_fp8: bool, primary_weights_in_fp8: bool,
ub_overlap_rs: bool, ub_overlap_rs: bool,
ub_overlap_ag: bool, ub_overlap_ag: bool,
ub_name: str ub_name: str,
dummy_tensor: torch.Tensor, # pylint: disable=unused-argument
) -> torch.Tensor: ) -> torch.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = weight.shape[-1] in_features = weight.shape[-1]
...@@ -90,7 +91,12 @@ class _Linear(torch.autograd.Function): ...@@ -90,7 +91,12 @@ class _Linear(torch.autograd.Function):
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_exec(weight) assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = (
is_first_microbatch is None
or is_first_microbatch
or skip_fp8_weight_update is not None
)
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
...@@ -140,7 +146,6 @@ class _Linear(torch.autograd.Function): ...@@ -140,7 +146,6 @@ class _Linear(torch.autograd.Function):
# Weight is already in FP8 # Weight is already in FP8
weight.reset_fp8_meta_scale_inv() weight.reset_fp8_meta_scale_inv()
weight_fp8 = weight weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights: elif update_fp8_weights:
# Need to cast weights to FP8 # Need to cast weights to FP8
weight_fp8 = Float8Tensor( weight_fp8 = Float8Tensor(
...@@ -158,6 +163,7 @@ class _Linear(torch.autograd.Function): ...@@ -158,6 +163,7 @@ class _Linear(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
cast_out=weight_fp8._data, cast_out=weight_fp8._data,
transpose_out=weight_t_fp8._data, transpose_out=weight_t_fp8._data,
noop_flag=skip_fp8_weight_update,
) )
else: else:
cast_to_fp8( cast_to_fp8(
...@@ -296,6 +302,7 @@ class _Linear(torch.autograd.Function): ...@@ -296,6 +302,7 @@ class _Linear(torch.autograd.Function):
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None, weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8 ctx.fp8 = fp8
...@@ -313,6 +320,7 @@ class _Linear(torch.autograd.Function): ...@@ -313,6 +320,7 @@ class _Linear(torch.autograd.Function):
ctx.ub_name = ub_name ctx.ub_name = ub_name
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.primary_weights_in_fp8 = primary_weights_in_fp8
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
...@@ -330,9 +338,7 @@ class _Linear(torch.autograd.Function): ...@@ -330,9 +338,7 @@ class _Linear(torch.autograd.Function):
def backward( def backward(
ctx, grad_output: torch.Tensor ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward( with torch.cuda.nvtx.range("_Linear_backward"):
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
):
( (
inputmat, inputmat,
inputmat_t, inputmat_t,
...@@ -340,6 +346,7 @@ class _Linear(torch.autograd.Function): ...@@ -340,6 +346,7 @@ class _Linear(torch.autograd.Function):
main_grad, main_grad,
weight_t_fp8, weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
skip_fp8_weight_update,
) = ctx.saved_tensors ) = ctx.saved_tensors
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
...@@ -347,10 +354,14 @@ class _Linear(torch.autograd.Function): ...@@ -347,10 +354,14 @@ class _Linear(torch.autograd.Function):
weight.main_grad = main_grad weight.main_grad = main_grad
# Primary weights are in FP8. # Primary weights are in FP8.
if ctx.fp8 and weight_t_fp8 is None: if ctx.primary_weights_in_fp8:
weight_t_fp8 = weight.transpose( weight_t_fp8 = weight.transpose_2d(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", cache=ctx.is_first_microbatch is not None,
noop_flag=skip_fp8_weight_update,
) )
elif ctx.fp8:
weight_t_fp8 = weight_t_fp8._data
tp_world_size = get_distributed_world_size(ctx.tp_group) tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
if ctx.ub_overlap_ag: if ctx.ub_overlap_ag:
...@@ -361,6 +372,7 @@ class _Linear(torch.autograd.Function): ...@@ -361,6 +372,7 @@ class _Linear(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
else: else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
( (
grad_output, grad_output,
grad_output_c, grad_output_c,
...@@ -401,7 +413,7 @@ class _Linear(torch.autograd.Function): ...@@ -401,7 +413,7 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad: if ctx.requires_dgrad:
if ctx.fp8: if ctx.fp8:
dgrad, _ = fp8_gemm( dgrad, _ = fp8_gemm(
weight_t_fp8._data, weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward, fp8_dtype_forward,
...@@ -542,6 +554,8 @@ class _Linear(torch.autograd.Function): ...@@ -542,6 +554,8 @@ class _Linear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
) )
...@@ -772,7 +786,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -772,7 +786,6 @@ class Linear(TransformerEngineBaseModule):
if self.primary_weights_in_fp8: if self.primary_weights_in_fp8:
self.init_fp8_metadata() self.init_fp8_metadata()
self.fp8_meta["update_amax_and_scale_fwd"] = True
self.reset_parameters(defer_init=(device == 'meta')) self.reset_parameters(defer_init=(device == 'meta'))
...@@ -785,6 +798,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -785,6 +798,10 @@ class Linear(TransformerEngineBaseModule):
else: else:
self.gemm_bias_unfused_add = False self.gemm_bias_unfused_add = False
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor)
def reset_parameters(self, defer_init=False): def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init) super().reset_parameters(defer_init=defer_init)
...@@ -858,6 +875,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -858,6 +875,10 @@ class Linear(TransformerEngineBaseModule):
produced) produced)
""" """
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, is_first_microbatch) as inp: with self.prepare_forward(inp, is_first_microbatch) as inp:
assert self.fp8 or not self.primary_weights_in_fp8, \ assert self.fp8 or not self.primary_weights_in_fp8, \
"Need to run inside fp8_autocast region when weights are stored in FP8." "Need to run inside fp8_autocast region when weights are stored in FP8."
...@@ -903,6 +924,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -903,6 +924,7 @@ class Linear(TransformerEngineBaseModule):
bias_tensor, bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch, is_first_microbatch,
skip_fp8_weight_update,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
...@@ -919,6 +941,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -919,6 +941,7 @@ class Linear(TransformerEngineBaseModule):
self.ub_overlap_rs, self.ub_overlap_rs,
self.ub_overlap_ag, self.ub_overlap_ag,
self.ub_name, self.ub_name,
self.dummy_tensor,
) )
out = linear_fn(*args) out = linear_fn(*args)
......
...@@ -473,6 +473,15 @@ class TransformerLayer(torch.nn.Module): ...@@ -473,6 +473,15 @@ class TransformerLayer(torch.nn.Module):
if hasattr(child, "set_tensor_parallel_group"): if hasattr(child, "set_tensor_parallel_group"):
child.set_tensor_parallel_group(tp_group) child.set_tensor_parallel_group(tp_group)
def reset_fp8_meta_tensors(self) -> None:
    """Forward the `reset_fp8_meta_tensors()` call to every submodule that defines it."""
    # Deep iterate but skip self to avoid infinite recursion.
    for index, child in enumerate(self.modules()):
        if index == 0:
            continue
        # Only submodules that expose the method are touched; others are ignored.
        if hasattr(child, "reset_fp8_meta_tensors"):
            child.reset_fp8_meta_tensors()
def set_context_parallel_group( def set_context_parallel_group(
self, self,
cp_group: Union[dist_group_type, None], cp_group: Union[dist_group_type, None],
...@@ -665,7 +674,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -665,7 +674,8 @@ class TransformerLayer(torch.nn.Module):
# MLP. # MLP.
mlp_outputs = self.layernorm_mlp( mlp_outputs = self.layernorm_mlp(
hidden_states, is_first_microbatch=is_first_microbatch hidden_states,
is_first_microbatch=is_first_microbatch,
) )
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs mlp_output, mlp_bias, residual = mlp_outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment