"vscode:/vscode.git/clone" did not exist on "4ff0203987ff100eaaad69f0a8abf7ed821e3a0a"
Unverified commit aadd3e7c, authored by Przemyslaw Tredak, committed by GitHub
Browse files

Add NVTX to TE modules (#50)



* Add NVTX to TE modules
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix pylint
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix NVTX in _prepare_backward
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Add NVTX to C API
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix cpplint and link nvToolsExt
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Add NVTX to GeGlu
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent eed1fa26
......@@ -22,7 +22,8 @@ disable=too-many-locals,
attribute-defined-outside-init,
global-statement,
too-many-branches,
global-variable-not-assigned
global-variable-not-assigned,
redefined-argument-from-local
[TYPECHECK]
ignored-modules=torch
......
......@@ -40,9 +40,9 @@ add_library(transformer_engine SHARED
target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")
find_package(CUDAToolkit REQUIRED cublas)
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart)
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
......
......@@ -116,6 +116,7 @@ void dgeglu(const Tensor &grad,
void nvte_gelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine;
gelu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
......@@ -125,6 +126,7 @@ void nvte_gelu(const NVTETensor input,
void nvte_geglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
geglu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
......@@ -135,6 +137,7 @@ void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgeglu(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
......
......@@ -20,6 +20,7 @@
#include <string>
#include <tuple>
#include <vector>
#include "nvtx.h"
namespace transformer_engine {
......@@ -285,6 +286,9 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
bool is_fp8_dtype(const DType t);
// Opens an NVTX profiling range labeled with the API function's name for the
// duration of the enclosing scope (RAII via transformer_engine::nvtx::NVTXWrapper:
// pushed here, popped when the scope exits). Intended as the first statement of
// an nvte_* C API entry point, e.g. NVTE_API_CALL(nvte_gelu).
#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _ ## api_name ## _nvtx_wrapper(#api_name);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
......@@ -1060,12 +1060,13 @@ void nvte_scaled_softmax_forward(
float scale_factor,
cudaStream_t stream
) {
using namespace transformer_engine;
scaled_softmax_forward(
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(softmax_results),
scale_factor,
stream);
NVTE_API_CALL(nvte_scaled_softmax_forward);
using namespace transformer_engine;
scaled_softmax_forward(
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(softmax_results),
scale_factor,
stream);
}
......@@ -1076,13 +1077,14 @@ void nvte_scaled_softmax_backward(
float scale_factor,
cudaStream_t stream
) {
using namespace transformer_engine;
scaled_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads),
*reinterpret_cast<const Tensor*>(incoming_grads),
*reinterpret_cast<const Tensor*>(softmax_results),
scale_factor,
stream);
NVTE_API_CALL(nvte_scaled_softmax_backward);
using namespace transformer_engine;
scaled_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads),
*reinterpret_cast<const Tensor*>(incoming_grads),
*reinterpret_cast<const Tensor*>(softmax_results),
scale_factor,
stream);
}
......@@ -1093,13 +1095,14 @@ void nvte_scaled_masked_softmax_forward(
float scale_factor,
cudaStream_t stream
) {
using namespace transformer_engine;
scaled_masked_softmax_forward(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(mask),
reinterpret_cast<Tensor*>(softmax_results),
scale_factor,
stream);
NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
using namespace transformer_engine;
scaled_masked_softmax_forward(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(mask),
reinterpret_cast<Tensor*>(softmax_results),
scale_factor,
stream);
}
......@@ -1110,11 +1113,12 @@ void nvte_scaled_masked_softmax_backward(
float scale_factor,
cudaStream_t stream
) {
using namespace transformer_engine;
scaled_masked_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads),
*reinterpret_cast<const Tensor*>(incoming_grads),
*reinterpret_cast<const Tensor*>(softmax_results),
scale_factor,
stream);
NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
using namespace transformer_engine;
scaled_masked_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads),
*reinterpret_cast<const Tensor*>(incoming_grads),
*reinterpret_cast<const Tensor*>(softmax_results),
scale_factor,
stream);
}
......@@ -236,6 +236,7 @@ void nvte_cublas_gemm(const NVTETensor A,
bool accumulate,
bool use_split_accumulator,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
......
......@@ -375,6 +375,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma),
......@@ -403,6 +404,7 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
*reinterpret_cast<const Tensor*>(x),
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_
#define TRANSFORMER_ENGINE_COMMON_NVTX_H_
#include <string>
#include <nvToolsExt.h>
namespace transformer_engine::nvtx {
/*! \brief RAII helper that wraps a scope in a named NVTX profiling range.
 *
 *  The constructor pushes an NVTX range and the destructor pops it, so an
 *  NVTXWrapper placed at the top of a function marks the whole call in NVTX
 *  aware profilers (e.g. Nsight Systems). Ranges nest with construction order.
 */
struct NVTXWrapper {
  /*! Push a new NVTX range labeled \p name. */
  explicit NVTXWrapper(const std::string &name) {
    nvtxRangePush(name.c_str());
  }
  /*! Pop the range opened by the constructor. */
  ~NVTXWrapper() {
    nvtxRangePop();
  }
};
} // namespace transformer_engine::nvtx
#endif // TRANSFORMER_ENGINE_COMMON_NVTX_H_
......@@ -287,11 +287,12 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier));
NVTE_API_CALL(nvte_rmsnorm_fwd);
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier));
}
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
......@@ -300,10 +301,11 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier));
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier));
}
......@@ -392,6 +392,7 @@ void nvte_cast_transpose(const NVTETensor input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(cast_output),
......
......@@ -1566,6 +1566,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias);
using namespace transformer_engine;
cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(cast_output),
......@@ -1582,6 +1583,7 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
using namespace transformer_engine;
cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gelu_input),
......@@ -1597,6 +1599,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine;
dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(geglu_input),
......
......@@ -339,6 +339,7 @@ void nvte_multi_cast_transpose(size_t num_tensors,
NVTETensor* cast_output_list,
NVTETensor* transposed_output_list,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_cast_transpose);
using namespace transformer_engine;
std::vector<Tensor*> input_list_,
cast_output_list_, transposed_output_list_;
......
......@@ -308,6 +308,7 @@ void transpose(const Tensor &input,
void nvte_transpose(const NVTETensor input,
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(transposed_output),
......
......@@ -95,6 +95,7 @@ void fp8_dequantize(const Tensor &input,
void nvte_fp8_quantize(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_quantize);
using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
......@@ -104,6 +105,7 @@ void nvte_fp8_quantize(const NVTETensor input,
void nvte_fp8_dequantize(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_dequantize);
using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
......
......@@ -8,6 +8,7 @@ import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping
from functools import partial
from contextlib import contextmanager
import torch
from torch.nn.parameter import Parameter
......@@ -96,6 +97,43 @@ def get_workspace() -> torch.Tensor:
)
return _cublas_workspace
@contextmanager
def _prepare_backward(fp8: bool,
                      fp8_meta: Dict[str, Any],
                      reduce_amax_across_tp_group: bool,
                      tp_group: Union[dist_group_type, None],
                      name: str = ""):
    """Context manager bracketing a TE module's backward pass.

    Entry (FP8 only): refreshes amax history and scaling factors. When the
    recipe requests global amax reduction, it additionally pulls the amaxes
    reduced in the previous iteration from the global buffer, schedules this
    module's buffer keys for deletion, assigns/advances the backward autocast
    id (seeded from ``autocast_id_fwd`` on first use), and registers the
    backward amaxes in the global buffer.

    The wrapped region executes inside an NVTX range named ``<name> backward``.

    Exit (FP8 with ``reduce_amax`` only): the module flagged as
    ``first_module`` launches the global amax reduction and clears the
    buffered keys.
    """
    if fp8:
        recipe = fp8_meta["recipe"]
        if not recipe.reduce_amax:
            # Purely local update; no global-reduction bookkeeping required.
            amax_and_scale_update(fp8_meta, False)
        else:
            # Consume the amaxes reduced during the previous iteration.
            copy_amax_from_global_buffer(fp8_meta, forward=False)
            amax_and_scale_update(fp8_meta, False)
            set_amax_buffer_key_deletion(fp8_meta, forward=False)

            # First backward reuses the forward autocast id; later ones bump it.
            if "autocast_id_bwd" in fp8_meta:
                fp8_meta["autocast_id_bwd"] += 1
            else:
                fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
            add_amax_to_global_buffer(fp8_meta, forward=False)

    with torch.cuda.nvtx.range(name + " backward"):
        yield

    if not fp8 or not fp8_meta["recipe"].reduce_amax:
        return

    if fp8_meta["first_module"]:
        global_amax_reduction(
            fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
        )
        # NOTE(review): key deletion assumed to be owned by the first FP8
        # module alongside the reduction — confirm against the original
        # post_backward indentation (the scraped diff loses it).
        delete_key_from_amax_buffer(forward=False)
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
......@@ -323,62 +361,62 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Allocate scales and amaxes
self.init_fp8_meta_tensors()
def pre_forward(self, inp: torch.Tensor, num_gemms: int = 1) -> None:
"""Checks and prep for FWD."""
@contextmanager
def prepare_forward(self, inp: torch.Tensor, num_gemms: int = 1):
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
return inp.contiguous()
assert inp.is_cuda, "TransformerEngine needs CUDA."
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."
if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."
if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."
self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(self.fp8_meta, True)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(self.fp8_meta, True)
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(self.fp8_meta, True)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
# Activation recomputation is used and this is the first forward phase.
if (
self.fp8
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
amax_and_scale_update(self.fp8_meta, True)
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
return inp.contiguous()
# Activation recomputation is used and this is the first forward phase.
if (
self.fp8
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
def post_forward(self) -> None:
"""This is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent.
"""
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
yield inp.contiguous()
if self.fp8 and in_fp8_activation_recompute_phase():
restore_fp8_meta_tensors(self.fp8_meta)
......@@ -395,47 +433,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
)
setup_amax_forward_global_reduce_func(reduce_func)
@staticmethod
def pre_backward(fp8: bool, fp8_meta: Dict[str, Any]) -> None:
"""Checks and prep for BWD."""
if not fp8:
return
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
return
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
if "autocast_id_bwd" not in fp8_meta:
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
else:
fp8_meta["autocast_id_bwd"] += 1
add_amax_to_global_buffer(fp8_meta, forward=False)
@staticmethod
def post_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
reduce_amax_across_tp_group: bool,
tp_group: Union[dist_group_type, None],
) -> None:
"""Checks and prep for BWD."""
if not fp8 or not fp8_meta["recipe"].reduce_amax:
return
if fp8_meta["first_module"]:
global_amax_reduction(
fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
)
delete_key_from_amax_buffer(forward=False)
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the
......@@ -699,160 +696,156 @@ class _LayerNormLinear(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight_t_fp8,
ln_out,
fwd_scale_inverses,
) = ctx.saved_tensors
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], ctx.parallel_mode == "row"
)
with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_LayerNormLinear"):
(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight_t_fp8,
ln_out,
fwd_scale_inverses,
) = ctx.saved_tensors
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], ctx.parallel_mode == "row"
)
else:
ln_out_total = ln_out
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
else:
ln_out_total = ln_out
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
wgrad, _, _ = gemm(
ln_out_total_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
ln_out_total_c = cast_from_fp8(
# WGRAD
wgrad, grad_bias, _ = gemm(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
wgrad, _, _ = gemm(
ln_out_total_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
ln_out_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
# LayerNorm gradient
d_ln_out = dgrad.view(inputmat.shape)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
# LayerNorm gradient
d_ln_out = dgrad.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight
)
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
if not ctx.use_bias:
grad_bias = None
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight
)
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
if not ctx.use_bias:
grad_bias = None
return (
dxmat.view(ctx.inp_shape),
......@@ -1091,33 +1084,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced)
"""
inp = self.pre_forward(inp)
bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight if weight is not None else self.weight,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
bias_tensor,
self.use_bias,
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
)
self.post_forward()
with self.prepare_forward(inp) as inp: # pylint
bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight if weight is not None else self.weight,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
bias_tensor,
self.use_bias,
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
)
if self.return_layernorm_output:
out, ln_out = out
......@@ -1274,144 +1264,140 @@ class _Linear(torch.autograd.Function):
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_Linear"):
(
inputmat,
inputmat_t,
weight,
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
(
inputmat,
inputmat_t,
weight,
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_output, ctx.parallel_mode == "row"
)
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_output, ctx.parallel_mode == "row"
)
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat_t_total, handle = gather_along_last_dim(
inputmat_t, ctx.tp_group, async_op=True
)
else:
inputmat_total, handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=True
)
else:
inputmat_t_total = inputmat_t
inputmat_total = inputmat
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat_t_total, handle = gather_along_last_dim(
inputmat_t, ctx.tp_group, async_op=True
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
inputmat_total, handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=True
)
else:
inputmat_t_total = inputmat_t
inputmat_total = inputmat
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# DGRAD
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
# DGRAD
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
wgrad = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
wgrad = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
wgrad, _, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
wgrad, _, _ = gemm(
# WGRAD
wgrad, grad_bias, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
if not ctx.use_bias:
grad_bias = None
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
if not ctx.use_bias:
grad_bias = None
return (
wgrad if weight.requires_grad else None,
......@@ -1617,29 +1603,26 @@ class Linear(TransformerEngineBaseModule):
produced)
"""
inp = self.pre_forward(inp)
bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply(
weight if weight is not None else self.weight,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
inp,
bias_tensor,
self.use_bias,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
)
self.post_forward()
with self.prepare_forward(inp) as inp:
bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply(
weight if weight is not None else self.weight,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
inp,
bias_tensor,
self.use_bias,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
)
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
......@@ -1871,111 +1854,165 @@ class _LayerNormMLP(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_LayerNormMLP"):
(
inputmat,
ln_weight,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight,
fc1_weight_t_fp8,
fc2_weight,
fc2_weight_t_fp8,
fc1_bias,
fwd_scale_inverses,
) = ctx.saved_tensors
(
inputmat,
ln_weight,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight,
fc1_weight_t_fp8,
fc2_weight,
fc2_weight_t_fp8,
fc1_bias,
fwd_scale_inverses,
) = ctx.saved_tensors
(
grad_output,
grad_output_c,
grad_output_t,
fc2_bias_grad,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True
)
(
grad_output,
grad_output_c,
grad_output_t,
fc2_bias_grad,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True
)
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.set_parallel_mode and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
else:
ln_out_total = ln_out
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.set_parallel_mode and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
else:
ln_out_total = ln_out
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# FC2 DGRAD; Unconditional
fc2_dgrad = fp8_gemm(
fc2_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
# FC2 DGRAD; Unconditional
fc2_dgrad = fp8_gemm(
fc2_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if fc2_weight.requires_grad:
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fc2_wgrad = fp8_gemm(
gelu_out_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
fc2_dgrad,
fc1_out,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
else:
if fc2_weight.requires_grad:
gelu_out_c = cast_from_fp8(
gelu_out,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc2_wgrad, _, _ = gemm(
gelu_out_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
)
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
fc2_dgrad, fc1_out, fc1_bias
)
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if fc2_weight.requires_grad:
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fc2_wgrad = fp8_gemm(
gelu_out_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
dgelu = cast_to_fp8(
dgelu_no_fp8,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
dgelu_t = None
fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
fc2_dgrad,
fc1_out,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
# FC1 DGRAD: Unconditional
fc1_dgrad = fp8_gemm(
fc1_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
dgelu,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT2],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# FC2 DGRAD; Unconditional
fc2_dgrad, _, _ = gemm(
fc2_weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
gelu=not ctx.bias_gelu_nvfusion,
grad=True,
gelu_input=fc1_out,
)
# FC2 WGRAD
if fc2_weight.requires_grad:
gelu_out_c = cast_from_fp8(
fc2_wgrad, fc2_bias_grad, _ = gemm(
gelu_out,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc2_wgrad, _, _ = gemm(
gelu_out_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
......@@ -1984,172 +2021,114 @@ class _LayerNormMLP(torch.autograd.Function):
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
fc2_dgrad, fc1_out, fc1_bias
)
dgelu = cast_to_fp8(
dgelu_no_fp8,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
dgelu_t = None
# FC1 DGRAD: Unconditional
fc1_dgrad = fp8_gemm(
fc1_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
dgelu,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT2],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# FC2 DGRAD; Unconditional
fc2_dgrad, _, _ = gemm(
fc2_weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
gelu=not ctx.bias_gelu_nvfusion,
grad=True,
gelu_input=fc1_out,
)
if ctx.bias_gelu_nvfusion:
fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
else:
dgelu = fc2_dgrad
# FC2 WGRAD
if fc2_weight.requires_grad:
fc2_wgrad, fc2_bias_grad, _ = gemm(
gelu_out,
grad_output,
# FC1 DGRAD: Unconditional
fc1_dgrad, _, _ = gemm(
fc1_weight,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NT",
layout="NN",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if ctx.bias_gelu_nvfusion:
fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
else:
dgelu = fc2_dgrad
# FC1 DGRAD: Unconditional
fc1_dgrad, _, _ = gemm(
fc1_weight,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.set_parallel_mode and ctx.sequence_parallel:
handle.wait()
fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True
)
elif ctx.set_parallel_mode and ctx.tensor_parallel:
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
if fc1_weight.requires_grad:
if ctx.fp8:
# FC1 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
dgelu_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT2
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.set_parallel_mode and ctx.sequence_parallel:
handle.wait()
fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True
)
elif ctx.set_parallel_mode and ctx.tensor_parallel:
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
if fc1_weight.requires_grad:
if ctx.fp8:
# FC1 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
dgelu_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT2
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc1_wgrad, _, _ = gemm(
ln_out_total_c,
dgelu_no_fp8,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
)
else:
ln_out_total_c = cast_from_fp8(
# FC1 WGRAD
fc1_wgrad_outputs = gemm(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc1_wgrad, _, _ = gemm(
ln_out_total_c,
dgelu_no_fp8,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=not ctx.bias_gelu_nvfusion,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
# FC1 WGRAD
fc1_wgrad_outputs = gemm(
ln_out_total,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=not ctx.bias_gelu_nvfusion,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if ctx.bias_gelu_nvfusion:
fc1_wgrad, _, _ = fc1_wgrad_outputs
else:
fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
if ctx.bias_gelu_nvfusion:
fc1_wgrad, _, _ = fc1_wgrad_outputs
else:
fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
# Column Parallel Linear
if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
handle.wait()
# LayerNorm gradient
d_ln_out = fc1_dgrad.view(inputmat.shape)
# Column Parallel Linear
if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
handle.wait()
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
# LayerNorm gradient
d_ln_out = fc1_dgrad.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight
)
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
if not ctx.use_bias:
fc2_bias_grad = None
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight
)
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
if not ctx.use_bias:
fc2_bias_grad = None
return (
dxmat.view(ctx.inp_shape),
......@@ -2420,36 +2399,33 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced)
"""
inp = self.pre_forward(inp, num_gemms=2)
out = _LayerNormMLP.apply(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.fc1_weight,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
self.fc1_bias,
self.fc2_weight,
self.weight2_fp8 if self.fp8 else None,
self.weight2_t_fp8 if self.fp8 else None,
self.fc2_bias,
self.use_bias,
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.bias_gelu_nvfusion,
self.set_parallel_mode,
)
self.post_forward()
with self.prepare_forward(inp, num_gemms=2) as inp:
out = _LayerNormMLP.apply(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.fc1_weight,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
self.fc1_bias,
self.fc2_weight,
self.weight2_fp8 if self.fp8 else None,
self.weight2_t_fp8 if self.fp8 else None,
self.fc2_bias,
self.use_bias,
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.bias_gelu_nvfusion,
self.set_parallel_mode,
)
if self.return_layernorm_output:
out, ln_out = out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment