Unverified Commit aadd3e7c authored by Przemyslaw Tredak, committed by GitHub

Add NVTX to TE modules (#50)



* Add NVTX to TE modules
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix pylint
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix NVTX in _prepare_backward
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Add NVTX to C API
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix cpplint and link nvToolsExt
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Add NVTX to GeGlu
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent eed1fa26
@@ -22,7 +22,8 @@ disable=too-many-locals,
attribute-defined-outside-init,
global-statement,
too-many-branches,
global-variable-not-assigned
global-variable-not-assigned,
redefined-argument-from-local
[TYPECHECK]
ignored-modules=torch
......
@@ -40,9 +40,9 @@ add_library(transformer_engine SHARED
target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")
find_package(CUDAToolkit REQUIRED cublas)
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart)
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
......
@@ -116,6 +116,7 @@ void dgeglu(const Tensor &grad,
void nvte_gelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine;
gelu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
@@ -125,6 +126,7 @@ void nvte_gelu(const NVTETensor input,
void nvte_geglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
geglu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
@@ -135,6 +137,7 @@ void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgeglu(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
......
@@ -20,6 +20,7 @@
#include <string>
#include <tuple>
#include <vector>
#include "nvtx.h"
namespace transformer_engine {
@@ -285,6 +286,9 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
bool is_fp8_dtype(const DType t);
#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _ ## api_name ## _nvtx_wrapper(#api_name);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
@@ -1060,6 +1060,7 @@ void nvte_scaled_softmax_forward(
float scale_factor,
cudaStream_t stream
) {
NVTE_API_CALL(nvte_scaled_softmax_forward);
using namespace transformer_engine;
scaled_softmax_forward(
*reinterpret_cast<const Tensor*>(input),
@@ -1076,6 +1077,7 @@ void nvte_scaled_softmax_backward(
float scale_factor,
cudaStream_t stream
) {
NVTE_API_CALL(nvte_scaled_softmax_backward);
using namespace transformer_engine;
scaled_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads),
@@ -1093,6 +1095,7 @@ void nvte_scaled_masked_softmax_forward(
float scale_factor,
cudaStream_t stream
) {
NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
using namespace transformer_engine;
scaled_masked_softmax_forward(
*reinterpret_cast<const Tensor*>(input),
@@ -1110,6 +1113,7 @@ void nvte_scaled_masked_softmax_backward(
float scale_factor,
cudaStream_t stream
) {
NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
using namespace transformer_engine;
scaled_masked_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads),
......
@@ -236,6 +236,7 @@ void nvte_cublas_gemm(const NVTETensor A,
bool accumulate,
bool use_split_accumulator,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
......
@@ -375,6 +375,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma),
@@ -403,6 +404,7 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
*reinterpret_cast<const Tensor*>(x),
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_
#define TRANSFORMER_ENGINE_COMMON_NVTX_H_
#include <string>
#include <nvToolsExt.h>
namespace transformer_engine::nvtx {
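// RAII guard: the constructor pushes an NVTX range and the destructor pops it,
// so the range is closed on every exit path of the enclosing scope.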
struct NVTXWrapper {
  explicit NVTXWrapper(const std::string &name) {
    nvtxRangePush(name.c_str());
  }

  ~NVTXWrapper() {
    nvtxRangePop();
  }
};
} // namespace transformer_engine::nvtx
#endif // TRANSFORMER_ENGINE_COMMON_NVTX_H_
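With this header in place, NVTE_API_CALL(nvte_gelu); expands to a declaration of a scope-local guard, transformer_engine::nvtx::NVTXWrapper _nvte_gelu_nvtx_wrapper("nvte_gelu"), whose lifetime brackets the entire C API call. The Python-side changes below get the same pairing from torch.cuda.nvtx.range; for illustration only (not part of this change), a minimal hand-rolled equivalent would look like:

from contextlib import contextmanager

import torch

@contextmanager
def nvtx_range(name: str):
    """Push an NVTX range on entry and pop it on exit, mirroring the
    constructor/destructor pairing of NVTXWrapper above."""
    torch.cuda.nvtx.range_push(name)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()

# The range covers everything in the block, including early exits and
# exceptions raised by the body.
with nvtx_range("nvte_gelu"):
    pass  # a kernel launch would go here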
@@ -287,6 +287,7 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_fwd);
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
@@ -300,6 +301,7 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
......
@@ -392,6 +392,7 @@ void nvte_cast_transpose(const NVTETensor input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(cast_output),
......
@@ -1566,6 +1566,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias);
using namespace transformer_engine;
cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(cast_output),
@@ -1582,6 +1583,7 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
using namespace transformer_engine;
cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gelu_input),
@@ -1597,6 +1599,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine;
dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(geglu_input),
......
@@ -339,6 +339,7 @@ void nvte_multi_cast_transpose(size_t num_tensors,
NVTETensor* cast_output_list,
NVTETensor* transposed_output_list,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_cast_transpose);
using namespace transformer_engine;
std::vector<Tensor*> input_list_,
cast_output_list_, transposed_output_list_;
......
@@ -308,6 +308,7 @@ void transpose(const Tensor &input,
void nvte_transpose(const NVTETensor input,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(transposed_output),
......
@@ -95,6 +95,7 @@ void fp8_dequantize(const Tensor &input,
void nvte_fp8_quantize(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_quantize);
using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
@@ -104,6 +105,7 @@ void nvte_fp8_quantize(const NVTETensor input,
void nvte_fp8_dequantize(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_dequantize);
using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
......
@@ -8,6 +8,7 @@ import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping
from functools import partial
from contextlib import contextmanager
import torch
from torch.nn.parameter import Parameter
@@ -96,6 +97,43 @@ def get_workspace() -> torch.Tensor:
)
return _cublas_workspace
@contextmanager
def _prepare_backward(fp8: bool,
                      fp8_meta: Dict[str, Any],
                      reduce_amax_across_tp_group: bool,
                      tp_group: Union[dist_group_type, None],
                      name: str = ""):
    """Checks and prep for BWD."""
    if fp8:
        # Update amax and scale; Skip all setup for global amax reduction
        if not fp8_meta["recipe"].reduce_amax:
            amax_and_scale_update(fp8_meta, False)
        else:
            # From previous iteration
            copy_amax_from_global_buffer(fp8_meta, forward=False)
            amax_and_scale_update(fp8_meta, False)
            set_amax_buffer_key_deletion(fp8_meta, forward=False)

            # Get new backward key.
            if "autocast_id_bwd" not in fp8_meta:
                fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
            else:
                fp8_meta["autocast_id_bwd"] += 1

            add_amax_to_global_buffer(fp8_meta, forward=False)

    with torch.cuda.nvtx.range(name + " backward"):
        yield

    if not fp8 or not fp8_meta["recipe"].reduce_amax:
        return

    if fp8_meta["first_module"]:
        global_amax_reduction(
            fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
        )
        delete_key_from_amax_buffer(forward=False)
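Autograd functions save their context in forward and run the whole gradient computation inside this manager; the _LayerNormLinear, _Linear, and _LayerNormMLP hunks below do exactly that. Condensed into a sketch with a hypothetical _ToyFunc (invented here for illustration):

class _ToyFunc(torch.autograd.Function):
    """Hypothetical example; the real call sites are the autograd
    functions further down in this diff."""

    @staticmethod
    def forward(ctx, inp, fp8, fp8_meta, sequence_parallel, tp_group):
        # Stash everything _prepare_backward will need later.
        ctx.fp8 = fp8
        ctx.fp8_meta = fp8_meta
        ctx.sequence_parallel = sequence_parallel
        ctx.tp_group = tp_group
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        # The body runs inside an NVTX range named "_ToyFunc backward";
        # amax bookkeeping runs on entry and the global amax reduction,
        # if enabled, on exit.
        with _prepare_backward(ctx.fp8, ctx.fp8_meta,
                               ctx.sequence_parallel, ctx.tp_group,
                               name="_ToyFunc"):
            grad_input = grad_output  # the real gradient math goes here
        return grad_input, None, None, None, None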
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
@@ -323,14 +361,20 @@
# Allocate scales and amaxes
self.init_fp8_meta_tensors()
def pre_forward(self, inp: torch.Tensor, num_gemms: int = 1) -> None:
"""Checks and prep for FWD."""
@contextmanager
def prepare_forward(self, inp: torch.Tensor, num_gemms: int = 1):
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to set up the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
return inp.contiguous()
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."
if self.tp_size > 1:
@@ -371,14 +415,8 @@
):
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
return inp.contiguous()
def post_forward(self) -> None:
"""This is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent.
"""
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
yield inp.contiguous()
if self.fp8 and in_fp8_activation_recompute_phase():
restore_fp8_meta_tensors(self.fp8_meta)
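Module forward bodies move under the returned context in the same way; LayerNormLinear, Linear, and LayerNormMLP below all follow the pattern sketched here with a hypothetical ToyModule (eliding the base class's other requirements):

class ToyModule(TransformerEngineBaseModule):
    """Hypothetical example of the new forward call pattern."""

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # FP8 setup runs on entry; the body executes inside an NVTX range
        # named "ToyModule forward"; recompute-phase restore runs on exit.
        with self.prepare_forward(inp) as inp:
            out = 2 * inp  # the real GEMMs/kernels go here
        return out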
@@ -395,47 +433,6 @@
)
setup_amax_forward_global_reduce_func(reduce_func)
@staticmethod
def pre_backward(fp8: bool, fp8_meta: Dict[str, Any]) -> None:
"""Checks and prep for BWD."""
if not fp8:
return
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
return
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
if "autocast_id_bwd" not in fp8_meta:
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
else:
fp8_meta["autocast_id_bwd"] += 1
add_amax_to_global_buffer(fp8_meta, forward=False)
@staticmethod
def post_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
reduce_amax_across_tp_group: bool,
tp_group: Union[dist_group_type, None],
) -> None:
"""Checks and prep for BWD."""
if not fp8 or not fp8_meta["recipe"].reduce_amax:
return
if fp8_meta["first_module"]:
global_amax_reduction(
fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
)
delete_key_from_amax_buffer(forward=False)
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the
@@ -699,8 +696,8 @@
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_LayerNormLinear"):
(
inputmat,
ln_weight,
@@ -850,10 +847,6 @@
if not ctx.use_bias:
grad_bias = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return (
dxmat.view(ctx.inp_shape),
dgamma,
@@ -1091,8 +1084,7 @@
produced)
"""
inp = self.pre_forward(inp)
with self.prepare_forward(inp) as inp:
bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply(
@@ -1117,8 +1109,6 @@
self.return_layernorm_output,
)
self.post_forward()
if self.return_layernorm_output:
out, ln_out = out
@@ -1274,8 +1264,8 @@
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_Linear"):
(
inputmat,
inputmat_t,
@@ -1409,10 +1399,6 @@
if not ctx.use_bias:
grad_bias = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return (
wgrad if weight.requires_grad else None,
None,
@@ -1617,8 +1603,7 @@
produced)
"""
inp = self.pre_forward(inp)
with self.prepare_forward(inp) as inp:
bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply(
@@ -1639,8 +1624,6 @@
self.parallel_mode,
)
self.post_forward()
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
@@ -1871,8 +1854,8 @@
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_LayerNormMLP"):
(
inputmat,
ln_weight,
@@ -2147,10 +2130,6 @@
if not ctx.use_bias:
fc2_bias_grad = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return (
dxmat.view(ctx.inp_shape),
dgamma,
@@ -2420,8 +2399,7 @@
produced)
"""
inp = self.pre_forward(inp, num_gemms=2)
with self.prepare_forward(inp, num_gemms=2) as inp:
out = _LayerNormMLP.apply(
inp,
self.layer_norm_weight,
@@ -2449,8 +2427,6 @@
self.set_parallel_mode,
)
self.post_forward()
if self.return_layernorm_output:
out, ln_out = out
......