"superbench/vscode:/vscode.git/clone" did not exist on "74421ffee0c3d3c5d055075078b3dcfcfa6d4ae2"
Unverified Commit aadd3e7c authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Add NVTX to TE modules (#50)



* Add NVTX to TE modules
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix pylint
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix NVTX in _prepare_backward
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Add NVTX to C API
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix cpplint and link nvToolsExt
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Add NVTX to GeGlu
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent eed1fa26
...@@ -22,7 +22,8 @@ disable=too-many-locals, ...@@ -22,7 +22,8 @@ disable=too-many-locals,
attribute-defined-outside-init, attribute-defined-outside-init,
global-statement, global-statement,
too-many-branches, too-many-branches,
global-variable-not-assigned global-variable-not-assigned,
redefined-argument-from-local
[TYPECHECK] [TYPECHECK]
ignored-modules=torch ignored-modules=torch
......
...@@ -40,9 +40,9 @@ add_library(transformer_engine SHARED ...@@ -40,9 +40,9 @@ add_library(transformer_engine SHARED
target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include") target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")
find_package(CUDAToolkit REQUIRED cublas) find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart) list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS}) target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
......
...@@ -116,6 +116,7 @@ void dgeglu(const Tensor &grad, ...@@ -116,6 +116,7 @@ void dgeglu(const Tensor &grad,
void nvte_gelu(const NVTETensor input, void nvte_gelu(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine; using namespace transformer_engine;
gelu_cast(*reinterpret_cast<const Tensor*>(input), gelu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
...@@ -125,6 +126,7 @@ void nvte_gelu(const NVTETensor input, ...@@ -125,6 +126,7 @@ void nvte_gelu(const NVTETensor input,
void nvte_geglu(const NVTETensor input, void nvte_geglu(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine; using namespace transformer_engine;
geglu_cast(*reinterpret_cast<const Tensor*>(input), geglu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
...@@ -135,6 +137,7 @@ void nvte_dgeglu(const NVTETensor grad, ...@@ -135,6 +137,7 @@ void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input, const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine; using namespace transformer_engine;
dgeglu(*reinterpret_cast<const Tensor*>(grad), dgeglu(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor*>(input),
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "nvtx.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -285,6 +286,9 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -285,6 +286,9 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
bool is_fp8_dtype(const DType t); bool is_fp8_dtype(const DType t);
// Opens an NVTX profiling range covering the remainder of the enclosing
// C API function: declares a local RAII guard (NVTXWrapper, see nvtx.h)
// named _<api_name>_nvtx_wrapper whose constructor pushes a range labelled
// with the stringized api_name and whose destructor pops it on scope exit.
#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _ ## api_name ## _nvtx_wrapper(#api_name);
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
...@@ -1060,6 +1060,7 @@ void nvte_scaled_softmax_forward( ...@@ -1060,6 +1060,7 @@ void nvte_scaled_softmax_forward(
float scale_factor, float scale_factor,
cudaStream_t stream cudaStream_t stream
) { ) {
NVTE_API_CALL(nvte_scaled_softmax_forward);
using namespace transformer_engine; using namespace transformer_engine;
scaled_softmax_forward( scaled_softmax_forward(
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor*>(input),
...@@ -1076,6 +1077,7 @@ void nvte_scaled_softmax_backward( ...@@ -1076,6 +1077,7 @@ void nvte_scaled_softmax_backward(
float scale_factor, float scale_factor,
cudaStream_t stream cudaStream_t stream
) { ) {
NVTE_API_CALL(nvte_scaled_softmax_backward);
using namespace transformer_engine; using namespace transformer_engine;
scaled_softmax_backward( scaled_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads), *reinterpret_cast<Tensor*>(output_grads),
...@@ -1093,6 +1095,7 @@ void nvte_scaled_masked_softmax_forward( ...@@ -1093,6 +1095,7 @@ void nvte_scaled_masked_softmax_forward(
float scale_factor, float scale_factor,
cudaStream_t stream cudaStream_t stream
) { ) {
NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
using namespace transformer_engine; using namespace transformer_engine;
scaled_masked_softmax_forward( scaled_masked_softmax_forward(
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor*>(input),
...@@ -1110,6 +1113,7 @@ void nvte_scaled_masked_softmax_backward( ...@@ -1110,6 +1113,7 @@ void nvte_scaled_masked_softmax_backward(
float scale_factor, float scale_factor,
cudaStream_t stream cudaStream_t stream
) { ) {
NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
using namespace transformer_engine; using namespace transformer_engine;
scaled_masked_softmax_backward( scaled_masked_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads), *reinterpret_cast<Tensor*>(output_grads),
......
...@@ -236,6 +236,7 @@ void nvte_cublas_gemm(const NVTETensor A, ...@@ -236,6 +236,7 @@ void nvte_cublas_gemm(const NVTETensor A,
bool accumulate, bool accumulate,
bool use_split_accumulator, bool use_split_accumulator,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A); const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B); const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
......
...@@ -375,6 +375,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size ...@@ -375,6 +375,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const int multiprocessorCount, const int multiprocessorCount,
NVTETensor workspace, NVTETensor workspace,
NVTETensor barrier) { NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine; using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma), *reinterpret_cast<const Tensor*>(gamma),
...@@ -403,6 +404,7 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size ...@@ -403,6 +404,7 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const int multiprocessorCount, const int multiprocessorCount,
NVTETensor workspace, NVTETensor workspace,
NVTETensor barrier) { NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine; using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(x),
......
/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_
#define TRANSFORMER_ENGINE_COMMON_NVTX_H_

#include <string>

#include <nvToolsExt.h>

namespace transformer_engine::nvtx {

// RAII guard for an NVTX profiling range: the constructor pushes a range
// onto the current thread's NVTX stack and the destructor pops it, so the
// range exactly covers the guard's lexical scope. Instantiated by the
// NVTE_API_CALL macro at the top of each C API entry point.
struct NVTXWrapper {
  // Pushes a new NVTX range labelled `name` for the current thread.
  explicit NVTXWrapper(const std::string &name) {
    nvtxRangePush(name.c_str());
  }

  // Non-copyable / non-assignable: a duplicate guard would pop the
  // range a second time and unbalance the NVTX range stack.
  NVTXWrapper(const NVTXWrapper &) = delete;
  NVTXWrapper &operator=(const NVTXWrapper &) = delete;

  // Pops the range pushed by the constructor.
  ~NVTXWrapper() {
    nvtxRangePop();
  }
};

}  // namespace transformer_engine::nvtx

#endif  // TRANSFORMER_ENGINE_COMMON_NVTX_H_
...@@ -287,6 +287,7 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size ...@@ -287,6 +287,7 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_fwd);
using namespace transformer_engine; using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma), rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream, epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
...@@ -300,6 +301,7 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size ...@@ -300,6 +301,7 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine; using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x), rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma), *reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
......
...@@ -392,6 +392,7 @@ void nvte_cast_transpose(const NVTETensor input, ...@@ -392,6 +392,7 @@ void nvte_cast_transpose(const NVTETensor input,
NVTETensor cast_output, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor*>(input), cast_transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor*>(cast_output),
......
...@@ -1566,6 +1566,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input, ...@@ -1566,6 +1566,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
NVTETensor dbias, NVTETensor dbias,
NVTETensor workspace, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias);
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input), cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor*>(cast_output),
...@@ -1582,6 +1583,7 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, ...@@ -1582,6 +1583,7 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
NVTETensor dbias, NVTETensor dbias,
NVTETensor workspace, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input), cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gelu_input), *reinterpret_cast<const Tensor*>(gelu_input),
...@@ -1597,6 +1599,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, ...@@ -1597,6 +1599,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
NVTETensor cast_output, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input), dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(geglu_input), *reinterpret_cast<const Tensor*>(geglu_input),
......
...@@ -339,6 +339,7 @@ void nvte_multi_cast_transpose(size_t num_tensors, ...@@ -339,6 +339,7 @@ void nvte_multi_cast_transpose(size_t num_tensors,
NVTETensor* cast_output_list, NVTETensor* cast_output_list,
NVTETensor* transposed_output_list, NVTETensor* transposed_output_list,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
std::vector<Tensor*> input_list_, std::vector<Tensor*> input_list_,
cast_output_list_, transposed_output_list_; cast_output_list_, transposed_output_list_;
......
...@@ -308,6 +308,7 @@ void transpose(const Tensor &input, ...@@ -308,6 +308,7 @@ void transpose(const Tensor &input,
void nvte_transpose(const NVTETensor input, void nvte_transpose(const NVTETensor input,
NVTETensor transposed_output, NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine; using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input), transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(transposed_output), reinterpret_cast<Tensor*>(transposed_output),
......
...@@ -95,6 +95,7 @@ void fp8_dequantize(const Tensor &input, ...@@ -95,6 +95,7 @@ void fp8_dequantize(const Tensor &input,
void nvte_fp8_quantize(const NVTETensor input, void nvte_fp8_quantize(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_quantize);
using namespace transformer_engine; using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor*>(input), fp8_quantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
...@@ -104,6 +105,7 @@ void nvte_fp8_quantize(const NVTETensor input, ...@@ -104,6 +105,7 @@ void nvte_fp8_quantize(const NVTETensor input,
void nvte_fp8_dequantize(const NVTETensor input, void nvte_fp8_dequantize(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_dequantize);
using namespace transformer_engine; using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor*>(input), fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
......
...@@ -8,6 +8,7 @@ import warnings ...@@ -8,6 +8,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping
from functools import partial from functools import partial
from contextlib import contextmanager
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -96,6 +97,43 @@ def get_workspace() -> torch.Tensor: ...@@ -96,6 +97,43 @@ def get_workspace() -> torch.Tensor:
) )
return _cublas_workspace return _cublas_workspace
@contextmanager
def _prepare_backward(fp8: bool,
                      fp8_meta: Dict[str, Any],
                      reduce_amax_across_tp_group: bool,
                      tp_group: Union[dist_group_type, None],
                      name: str = ""):
    """Checks and prep for BWD.

    Context manager wrapping a module's backward pass. On entry it performs
    the FP8 amax/scale bookkeeping for the backward direction; the wrapped
    body then runs inside an NVTX range labelled ``<name> backward``; on exit
    it triggers the aggregated amax reduction when this module is flagged as
    the first in the chain.

    # NOTE(review): indentation reconstructed from a whitespace-mangled
    # source — nesting of the key-allocation section under the ``else``
    # branch follows the pre-refactor ``pre_backward`` logic; confirm.
    """
    if fp8:
        # Update amax and scale; Skip all setup for global amax reduction
        if not fp8_meta["recipe"].reduce_amax:
            amax_and_scale_update(fp8_meta, False)
        else:
            # From previous iteration
            copy_amax_from_global_buffer(fp8_meta, forward=False)
            amax_and_scale_update(fp8_meta, False)
            set_amax_buffer_key_deletion(fp8_meta, forward=False)

            # Get new backward key.
            if "autocast_id_bwd" not in fp8_meta:
                # First backward after the forward autocast: reuse the
                # forward autocast id as the starting backward key.
                fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
            else:
                fp8_meta["autocast_id_bwd"] += 1

            add_amax_to_global_buffer(fp8_meta, forward=False)

    # Profile the actual backward computation under an NVTX range.
    with torch.cuda.nvtx.range(name + " backward"):
        yield

    # Post-backward: nothing to reduce unless FP8 with global amax reduction.
    if not fp8 or not fp8_meta["recipe"].reduce_amax:
        return

    # Only the module marked "first_module" performs the reduction for the
    # whole chain and releases the buffered backward key.
    if fp8_meta["first_module"]:
        global_amax_reduction(
            fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
        )
        delete_key_from_amax_buffer(forward=False)
class TransformerEngineBaseModule(torch.nn.Module, ABC): class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module.""" """Base TE module."""
...@@ -323,14 +361,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -323,14 +361,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Allocate scales and amaxes # Allocate scales and amaxes
self.init_fp8_meta_tensors() self.init_fp8_meta_tensors()
def pre_forward(self, inp: torch.Tensor, num_gemms: int = 1) -> None: @contextmanager
"""Checks and prep for FWD.""" def prepare_forward(self, inp: torch.Tensor, num_gemms: int = 1):
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase. # Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
return inp.contiguous() else:
assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.is_cuda, "TransformerEngine needs CUDA."
if self.tp_size > 1: if self.tp_size > 1:
...@@ -371,14 +415,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -371,14 +415,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
): ):
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
return inp.contiguous() with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
yield inp.contiguous()
def post_forward(self) -> None:
"""This is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent.
"""
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
restore_fp8_meta_tensors(self.fp8_meta) restore_fp8_meta_tensors(self.fp8_meta)
...@@ -395,47 +433,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -395,47 +433,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
) )
setup_amax_forward_global_reduce_func(reduce_func) setup_amax_forward_global_reduce_func(reduce_func)
@staticmethod
def pre_backward(fp8: bool, fp8_meta: Dict[str, Any]) -> None:
"""Checks and prep for BWD."""
if not fp8:
return
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
return
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
if "autocast_id_bwd" not in fp8_meta:
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
else:
fp8_meta["autocast_id_bwd"] += 1
add_amax_to_global_buffer(fp8_meta, forward=False)
@staticmethod
def post_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
reduce_amax_across_tp_group: bool,
tp_group: Union[dist_group_type, None],
) -> None:
"""Checks and prep for BWD."""
if not fp8 or not fp8_meta["recipe"].reduce_amax:
return
if fp8_meta["first_module"]:
global_amax_reduction(
fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
)
delete_key_from_amax_buffer(forward=False)
def set_nccl_overlap_warning_if_tp(self) -> None: def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled """When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the before the GEMM for there to be a guaranteed overlap. From the
...@@ -699,8 +696,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -699,8 +696,8 @@ class _LayerNormLinear(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta) with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_LayerNormLinear"):
( (
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -850,10 +847,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -850,10 +847,6 @@ class _LayerNormLinear(torch.autograd.Function):
if not ctx.use_bias: if not ctx.use_bias:
grad_bias = None grad_bias = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return ( return (
dxmat.view(ctx.inp_shape), dxmat.view(ctx.inp_shape),
dgamma, dgamma,
...@@ -1091,8 +1084,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1091,8 +1084,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced) produced)
""" """
inp = self.pre_forward(inp) with self.prepare_forward(inp) as inp: # pylint
bias_tensor = bias if bias is not None else self.bias bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply( out = _LayerNormLinear.apply(
...@@ -1117,8 +1109,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1117,8 +1109,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output, self.return_layernorm_output,
) )
self.post_forward()
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
...@@ -1274,8 +1264,8 @@ class _Linear(torch.autograd.Function): ...@@ -1274,8 +1264,8 @@ class _Linear(torch.autograd.Function):
def backward( def backward(
ctx, grad_output: torch.Tensor ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta) with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_Linear"):
( (
inputmat, inputmat,
inputmat_t, inputmat_t,
...@@ -1409,10 +1399,6 @@ class _Linear(torch.autograd.Function): ...@@ -1409,10 +1399,6 @@ class _Linear(torch.autograd.Function):
if not ctx.use_bias: if not ctx.use_bias:
grad_bias = None grad_bias = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return ( return (
wgrad if weight.requires_grad else None, wgrad if weight.requires_grad else None,
None, None,
...@@ -1617,8 +1603,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1617,8 +1603,7 @@ class Linear(TransformerEngineBaseModule):
produced) produced)
""" """
inp = self.pre_forward(inp) with self.prepare_forward(inp) as inp:
bias_tensor = bias if bias is not None else self.bias bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply( out = _Linear.apply(
...@@ -1639,8 +1624,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -1639,8 +1624,6 @@ class Linear(TransformerEngineBaseModule):
self.parallel_mode, self.parallel_mode,
) )
self.post_forward()
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype) out = out + cast_if_needed(bias_tensor, self.activation_dtype)
...@@ -1871,8 +1854,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1871,8 +1854,8 @@ class _LayerNormMLP(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta) with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_LayerNormMLP"):
( (
inputmat, inputmat,
ln_weight, ln_weight,
...@@ -2147,10 +2130,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -2147,10 +2130,6 @@ class _LayerNormMLP(torch.autograd.Function):
if not ctx.use_bias: if not ctx.use_bias:
fc2_bias_grad = None fc2_bias_grad = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return ( return (
dxmat.view(ctx.inp_shape), dxmat.view(ctx.inp_shape),
dgamma, dgamma,
...@@ -2420,8 +2399,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2420,8 +2399,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced) produced)
""" """
inp = self.pre_forward(inp, num_gemms=2) with self.prepare_forward(inp, num_gemms=2) as inp:
out = _LayerNormMLP.apply( out = _LayerNormMLP.apply(
inp, inp,
self.layer_norm_weight, self.layer_norm_weight,
...@@ -2449,8 +2427,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2449,8 +2427,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.set_parallel_mode, self.set_parallel_mode,
) )
self.post_forward()
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment