"vscode:/vscode.git/clone" did not exist on "ff760a9d838ae4617600cccb22131d0359ce0296"
Unverified Commit aadd3e7c authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Add NVTX to TE modules (#50)



* Add NVTX to TE modules
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix pylint
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix NVTX in _prepare_backward
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Add NVTX to C API
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix cpplint and link nvToolsExt
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Add NVTX to GeGlu
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent eed1fa26
...@@ -22,7 +22,8 @@ disable=too-many-locals, ...@@ -22,7 +22,8 @@ disable=too-many-locals,
attribute-defined-outside-init, attribute-defined-outside-init,
global-statement, global-statement,
too-many-branches, too-many-branches,
global-variable-not-assigned global-variable-not-assigned,
redefined-argument-from-local
[TYPECHECK] [TYPECHECK]
ignored-modules=torch ignored-modules=torch
......
...@@ -40,9 +40,9 @@ add_library(transformer_engine SHARED ...@@ -40,9 +40,9 @@ add_library(transformer_engine SHARED
target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include") target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")
find_package(CUDAToolkit REQUIRED cublas) find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart) list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS}) target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
......
...@@ -116,6 +116,7 @@ void dgeglu(const Tensor &grad, ...@@ -116,6 +116,7 @@ void dgeglu(const Tensor &grad,
void nvte_gelu(const NVTETensor input, void nvte_gelu(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine; using namespace transformer_engine;
gelu_cast(*reinterpret_cast<const Tensor*>(input), gelu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
...@@ -125,6 +126,7 @@ void nvte_gelu(const NVTETensor input, ...@@ -125,6 +126,7 @@ void nvte_gelu(const NVTETensor input,
void nvte_geglu(const NVTETensor input, void nvte_geglu(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine; using namespace transformer_engine;
geglu_cast(*reinterpret_cast<const Tensor*>(input), geglu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
...@@ -135,6 +137,7 @@ void nvte_dgeglu(const NVTETensor grad, ...@@ -135,6 +137,7 @@ void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input, const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine; using namespace transformer_engine;
dgeglu(*reinterpret_cast<const Tensor*>(grad), dgeglu(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor*>(input),
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "nvtx.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -285,6 +286,9 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -285,6 +286,9 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
bool is_fp8_dtype(const DType t); bool is_fp8_dtype(const DType t);
#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _ ## api_name ## _nvtx_wrapper(#api_name);
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
...@@ -1060,12 +1060,13 @@ void nvte_scaled_softmax_forward( ...@@ -1060,12 +1060,13 @@ void nvte_scaled_softmax_forward(
float scale_factor, float scale_factor,
cudaStream_t stream cudaStream_t stream
) { ) {
using namespace transformer_engine; NVTE_API_CALL(nvte_scaled_softmax_forward);
scaled_softmax_forward( using namespace transformer_engine;
*reinterpret_cast<const Tensor*>(input), scaled_softmax_forward(
reinterpret_cast<Tensor*>(softmax_results), *reinterpret_cast<const Tensor*>(input),
scale_factor, reinterpret_cast<Tensor*>(softmax_results),
stream); scale_factor,
stream);
} }
...@@ -1076,13 +1077,14 @@ void nvte_scaled_softmax_backward( ...@@ -1076,13 +1077,14 @@ void nvte_scaled_softmax_backward(
float scale_factor, float scale_factor,
cudaStream_t stream cudaStream_t stream
) { ) {
using namespace transformer_engine; NVTE_API_CALL(nvte_scaled_softmax_backward);
scaled_softmax_backward( using namespace transformer_engine;
*reinterpret_cast<Tensor*>(output_grads), scaled_softmax_backward(
*reinterpret_cast<const Tensor*>(incoming_grads), *reinterpret_cast<Tensor*>(output_grads),
*reinterpret_cast<const Tensor*>(softmax_results), *reinterpret_cast<const Tensor*>(incoming_grads),
scale_factor, *reinterpret_cast<const Tensor*>(softmax_results),
stream); scale_factor,
stream);
} }
...@@ -1093,13 +1095,14 @@ void nvte_scaled_masked_softmax_forward( ...@@ -1093,13 +1095,14 @@ void nvte_scaled_masked_softmax_forward(
float scale_factor, float scale_factor,
cudaStream_t stream cudaStream_t stream
) { ) {
using namespace transformer_engine; NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
scaled_masked_softmax_forward( using namespace transformer_engine;
*reinterpret_cast<const Tensor*>(input), scaled_masked_softmax_forward(
*reinterpret_cast<const Tensor*>(mask), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(softmax_results), *reinterpret_cast<const Tensor*>(mask),
scale_factor, reinterpret_cast<Tensor*>(softmax_results),
stream); scale_factor,
stream);
} }
...@@ -1110,11 +1113,12 @@ void nvte_scaled_masked_softmax_backward( ...@@ -1110,11 +1113,12 @@ void nvte_scaled_masked_softmax_backward(
float scale_factor, float scale_factor,
cudaStream_t stream cudaStream_t stream
) { ) {
using namespace transformer_engine; NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
scaled_masked_softmax_backward( using namespace transformer_engine;
*reinterpret_cast<Tensor*>(output_grads), scaled_masked_softmax_backward(
*reinterpret_cast<const Tensor*>(incoming_grads), *reinterpret_cast<Tensor*>(output_grads),
*reinterpret_cast<const Tensor*>(softmax_results), *reinterpret_cast<const Tensor*>(incoming_grads),
scale_factor, *reinterpret_cast<const Tensor*>(softmax_results),
stream); scale_factor,
stream);
} }
...@@ -236,6 +236,7 @@ void nvte_cublas_gemm(const NVTETensor A, ...@@ -236,6 +236,7 @@ void nvte_cublas_gemm(const NVTETensor A,
bool accumulate, bool accumulate,
bool use_split_accumulator, bool use_split_accumulator,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A); const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B); const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
......
...@@ -375,6 +375,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size ...@@ -375,6 +375,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const int multiprocessorCount, const int multiprocessorCount,
NVTETensor workspace, NVTETensor workspace,
NVTETensor barrier) { NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine; using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma), *reinterpret_cast<const Tensor*>(gamma),
...@@ -403,6 +404,7 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size ...@@ -403,6 +404,7 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
const int multiprocessorCount, const int multiprocessorCount,
NVTETensor workspace, NVTETensor workspace,
NVTETensor barrier) { NVTETensor barrier) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine; using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(x),
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_
#define TRANSFORMER_ENGINE_COMMON_NVTX_H_
#include <string>
#include <nvToolsExt.h>
namespace transformer_engine::nvtx {

/*! \brief RAII guard that wraps a scope in an NVTX range.
 *
 *  Pushes an NVTX range named \p name on construction and pops it on
 *  destruction, so the lifetime of the object shows up as a named range
 *  in profiler timelines (e.g. Nsight Systems). Intended to be placed at
 *  the top of an API entry point via the NVTE_API_CALL macro.
 */
struct NVTXWrapper {
  /*! \brief Push an NVTX range named \p name. */
  explicit NVTXWrapper(const std::string &name) {
    nvtxRangePush(name.c_str());
  }

  /*! \brief Pop the range pushed by the constructor. */
  ~NVTXWrapper() {
    nvtxRangePop();
  }

  // Non-copyable / non-movable: duplicating the guard would call
  // nvtxRangePop() more than once and corrupt the NVTX range stack.
  NVTXWrapper(const NVTXWrapper &) = delete;
  NVTXWrapper &operator=(const NVTXWrapper &) = delete;
};

}  // namespace transformer_engine::nvtx
#endif // TRANSFORMER_ENGINE_COMMON_NVTX_H_
...@@ -287,11 +287,12 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size ...@@ -287,11 +287,12 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
using namespace transformer_engine; NVTE_API_CALL(nvte_rmsnorm_fwd);
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma), using namespace transformer_engine;
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream, rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
multiprocessorCount, reinterpret_cast<Tensor *>(workspace), epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
reinterpret_cast<Tensor *>(barrier)); multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
reinterpret_cast<Tensor *>(barrier));
} }
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
...@@ -300,10 +301,11 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size ...@@ -300,10 +301,11 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
const NVTETensor gamma, // hidden_size const NVTETensor gamma, // hidden_size
NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
using namespace transformer_engine; NVTE_API_CALL(nvte_rmsnorm_bwd);
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x), using namespace transformer_engine;
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma), rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma), *reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount, reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier)); reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier));
} }
...@@ -392,6 +392,7 @@ void nvte_cast_transpose(const NVTETensor input, ...@@ -392,6 +392,7 @@ void nvte_cast_transpose(const NVTETensor input,
NVTETensor cast_output, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor*>(input), cast_transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor*>(cast_output),
......
...@@ -1566,6 +1566,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input, ...@@ -1566,6 +1566,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
NVTETensor dbias, NVTETensor dbias,
NVTETensor workspace, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias);
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input), cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor*>(cast_output),
...@@ -1582,6 +1583,7 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, ...@@ -1582,6 +1583,7 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
NVTETensor dbias, NVTETensor dbias,
NVTETensor workspace, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input), cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gelu_input), *reinterpret_cast<const Tensor*>(gelu_input),
...@@ -1597,6 +1599,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, ...@@ -1597,6 +1599,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
NVTETensor cast_output, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input), dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(geglu_input), *reinterpret_cast<const Tensor*>(geglu_input),
......
...@@ -339,6 +339,7 @@ void nvte_multi_cast_transpose(size_t num_tensors, ...@@ -339,6 +339,7 @@ void nvte_multi_cast_transpose(size_t num_tensors,
NVTETensor* cast_output_list, NVTETensor* cast_output_list,
NVTETensor* transposed_output_list, NVTETensor* transposed_output_list,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
std::vector<Tensor*> input_list_, std::vector<Tensor*> input_list_,
cast_output_list_, transposed_output_list_; cast_output_list_, transposed_output_list_;
......
...@@ -308,6 +308,7 @@ void transpose(const Tensor &input, ...@@ -308,6 +308,7 @@ void transpose(const Tensor &input,
void nvte_transpose(const NVTETensor input, void nvte_transpose(const NVTETensor input,
NVTETensor transposed_output, NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine; using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input), transpose(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(transposed_output), reinterpret_cast<Tensor*>(transposed_output),
......
...@@ -95,6 +95,7 @@ void fp8_dequantize(const Tensor &input, ...@@ -95,6 +95,7 @@ void fp8_dequantize(const Tensor &input,
void nvte_fp8_quantize(const NVTETensor input, void nvte_fp8_quantize(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_quantize);
using namespace transformer_engine; using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor*>(input), fp8_quantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
...@@ -104,6 +105,7 @@ void nvte_fp8_quantize(const NVTETensor input, ...@@ -104,6 +105,7 @@ void nvte_fp8_quantize(const NVTETensor input,
void nvte_fp8_dequantize(const NVTETensor input, void nvte_fp8_dequantize(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_dequantize);
using namespace transformer_engine; using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor*>(input), fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
......
...@@ -8,6 +8,7 @@ import warnings ...@@ -8,6 +8,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping
from functools import partial from functools import partial
from contextlib import contextmanager
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -96,6 +97,43 @@ def get_workspace() -> torch.Tensor: ...@@ -96,6 +97,43 @@ def get_workspace() -> torch.Tensor:
) )
return _cublas_workspace return _cublas_workspace
@contextmanager
def _prepare_backward(fp8: bool,
fp8_meta: Dict[str, Any],
reduce_amax_across_tp_group: bool,
tp_group: Union[dist_group_type, None],
name: str = ""):
"""Checks and prep for BWD."""
if fp8:
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
else:
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
if "autocast_id_bwd" not in fp8_meta:
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
else:
fp8_meta["autocast_id_bwd"] += 1
add_amax_to_global_buffer(fp8_meta, forward=False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if not fp8 or not fp8_meta["recipe"].reduce_amax:
return
if fp8_meta["first_module"]:
global_amax_reduction(
fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
)
delete_key_from_amax_buffer(forward=False)
class TransformerEngineBaseModule(torch.nn.Module, ABC): class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module.""" """Base TE module."""
...@@ -323,62 +361,62 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -323,62 +361,62 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Allocate scales and amaxes # Allocate scales and amaxes
self.init_fp8_meta_tensors() self.init_fp8_meta_tensors()
def pre_forward(self, inp: torch.Tensor, num_gemms: int = 1) -> None: @contextmanager
"""Checks and prep for FWD.""" def prepare_forward(self, inp: torch.Tensor, num_gemms: int = 1):
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase. # Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
return inp.contiguous() else:
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.is_cuda, "TransformerEngine needs CUDA."
if self.tp_size > 1: if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized." assert self.tp_group_initialized, "TP group not initialized."
self.set_activation_dtype(inp) self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms) self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights() self.set_fp8_weights()
# Previous iteration was grad_enabled # Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False): if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax: if self.fp8_meta["recipe"].reduce_amax:
copy_amax_from_global_buffer(self.fp8_meta, forward=True) copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(self.fp8_meta, True) amax_and_scale_update(self.fp8_meta, True)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True) set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(self.fp8_meta, True)
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else: else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id() amax_and_scale_update(self.fp8_meta, True)
add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True if self.fp8 and self.training:
else: # Setup for amax reduction
self.fp8_meta["update_amax_and_scale_fwd"] = False if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
# Activation recomputation is used and this is the first forward phase. if self.fp8_meta["first_module"]:
if ( self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
self.fp8 set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
and is_fp8_activation_recompute_enabled() else:
and not in_fp8_activation_recompute_phase() self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
): add_amax_to_global_buffer(self.fp8_meta, forward=True)
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
return inp.contiguous() # Activation recomputation is used and this is the first forward phase.
if (
self.fp8
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
def post_forward(self) -> None: with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
"""This is needed because there isn't a way for a module to know yield inp.contiguous()
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent.
"""
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
restore_fp8_meta_tensors(self.fp8_meta) restore_fp8_meta_tensors(self.fp8_meta)
...@@ -395,47 +433,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -395,47 +433,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
) )
setup_amax_forward_global_reduce_func(reduce_func) setup_amax_forward_global_reduce_func(reduce_func)
@staticmethod
def pre_backward(fp8: bool, fp8_meta: Dict[str, Any]) -> None:
"""Checks and prep for BWD."""
if not fp8:
return
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
return
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
if "autocast_id_bwd" not in fp8_meta:
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
else:
fp8_meta["autocast_id_bwd"] += 1
add_amax_to_global_buffer(fp8_meta, forward=False)
@staticmethod
def post_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
reduce_amax_across_tp_group: bool,
tp_group: Union[dist_group_type, None],
) -> None:
"""Checks and prep for BWD."""
if not fp8 or not fp8_meta["recipe"].reduce_amax:
return
if fp8_meta["first_module"]:
global_amax_reduction(
fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
)
delete_key_from_amax_buffer(forward=False)
def set_nccl_overlap_warning_if_tp(self) -> None: def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled """When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the before the GEMM for there to be a guaranteed overlap. From the
...@@ -699,160 +696,156 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -699,160 +696,156 @@ class _LayerNormLinear(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta) with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_LayerNormLinear"):
( (
inputmat, inputmat,
ln_weight, ln_weight,
mu, mu,
rsigma, rsigma,
weight, weight,
weight_t_fp8, weight_t_fp8,
ln_out, ln_out,
fwd_scale_inverses, fwd_scale_inverses,
) = ctx.saved_tensors ) = ctx.saved_tensors
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], ctx.parallel_mode == "row"
)
# Column Parallel Linear (
# Overlap input AG with dgrad grad_output,
if ctx.parallel_mode == "column" and ctx.sequence_parallel: grad_output_c,
ln_out_total, handle = gather_along_first_dim( grad_output_t,
ln_out, ctx.tp_group, async_op=True grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], ctx.parallel_mode == "row"
) )
else:
ln_out_total = ln_out
if ctx.is_first_microbatch is not None: # Column Parallel Linear
accumulate_wgrad_into_param_main_grad = ( # Overlap input AG with dgrad
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch if ctx.parallel_mode == "column" and ctx.sequence_parallel:
) ln_out_total, handle = gather_along_first_dim(
else: ln_out, ctx.tp_group, async_op=True
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation )
else:
ln_out_total = ln_out
if ctx.fp8: if ctx.is_first_microbatch is not None:
fp8_dtype_forward = get_fp8_te_dtype( accumulate_wgrad_into_param_main_grad = (
ctx.fp8_meta["recipe"], fprop_tensor=True ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
) )
fp8_dtype_backward = get_fp8_te_dtype( else:
ctx.fp8_meta["recipe"], fprop_tensor=False accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
)
# DGRAD: Evaluated unconditionally to feed into Linear backward if ctx.fp8:
dgrad = fp8_gemm( fp8_dtype_forward = get_fp8_te_dtype(
weight_t_fp8, ctx.fp8_meta["recipe"], fprop_tensor=True
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT], )
fp8_dtype_forward, fp8_dtype_backward = get_fp8_te_dtype(
grad_output_c, ctx.fp8_meta["recipe"], fprop_tensor=False
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1], )
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
# Overlap dgrad-RS/AR with wgrad # DGRAD: Evaluated unconditionally to feed into Linear backward
if ctx.parallel_mode == "column" and ctx.sequence_parallel: dgrad = fp8_gemm(
handle.wait() weight_t_fp8,
dgrad, handle = reduce_scatter_along_first_dim( fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
dgrad, ctx.tp_group, async_op=True fp8_dtype_forward,
) grad_output_c,
elif ctx.parallel_mode == "column" and ctx.tensor_parallel: ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
if weight.requires_grad: # Overlap dgrad-RS/AR with wgrad
if ctx.fp8: if ctx.parallel_mode == "column" and ctx.sequence_parallel:
# WGRAD handle.wait()
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: dgrad, handle = reduce_scatter_along_first_dim(
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) dgrad, ctx.tp_group, async_op=True
wgrad = fp8_gemm( )
ln_out_total_t, elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
fp8_dtype_forward,
grad_output_t, if weight.requires_grad:
ctx.fp8_meta["scaling_bwd"].scale_inv[ if ctx.fp8:
tex.FP8BwdTensors.GRAD_OUTPUT1 # WGRAD
], if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
fp8_dtype_backward, ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
ctx.activation_dtype, wgrad = fp8_gemm(
get_workspace(), ln_out_total_t,
accumulate=accumulate_wgrad_into_param_main_grad, fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp32_output=ctx.fuse_wgrad_accumulation, fp8_dtype_forward,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, grad_output_t,
use_split_accumulator=_2X_ACC_WGRAD, ctx.fp8_meta["scaling_bwd"].scale_inv[
) tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
wgrad, _, _ = gemm(
ln_out_total_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else: else:
ln_out_total_c = cast_from_fp8( # WGRAD
wgrad, grad_bias, _ = gemm(
ln_out_total, ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
wgrad, _, _ = gemm(
ln_out_total_c,
grad_output, grad_output,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
layout="NT", layout="NT",
grad=True, grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation, fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
) )
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
ln_out_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Column Parallel Linear # Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait() handle.wait()
# LayerNorm gradient
d_ln_out = dgrad.view(inputmat.shape)
# Residual gradient # LayerNorm gradient
if ctx.return_layernorm_output: d_ln_out = dgrad.view(inputmat.shape)
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd( # Residual gradient
d_ln_out, inputmat, mu, rsigma, ln_weight if ctx.return_layernorm_output:
) d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
if not ctx.use_bias: dxmat, dgamma, dbeta = tex.layernorm_bwd(
grad_bias = None d_ln_out, inputmat, mu, rsigma, ln_weight
)
TransformerEngineBaseModule.post_backward( if not ctx.use_bias:
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group grad_bias = None
)
return ( return (
dxmat.view(ctx.inp_shape), dxmat.view(ctx.inp_shape),
...@@ -1091,33 +1084,30 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1091,33 +1084,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced) produced)
""" """
inp = self.pre_forward(inp) with self.prepare_forward(inp) as inp: # pylint
bias_tensor = bias if bias is not None else self.bias
bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply(
out = _LayerNormLinear.apply( inp,
inp, self.layer_norm_weight,
self.layer_norm_weight, self.layer_norm_bias,
self.layer_norm_bias, weight if weight is not None else self.weight,
weight if weight is not None else self.weight, self.weight1_fp8 if self.fp8 else None,
self.weight1_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None, bias_tensor,
bias_tensor, self.use_bias,
self.use_bias, self.eps,
self.eps, is_first_microbatch,
is_first_microbatch, self.fp8,
self.fp8, self.fp8_meta,
self.fp8_meta, self.fuse_wgrad_accumulation,
self.fuse_wgrad_accumulation, self.tp_group,
self.tp_group, self.sequence_parallel,
self.sequence_parallel, self.tp_size > 1,
self.tp_size > 1, self.activation_dtype,
self.activation_dtype, self.parallel_mode,
self.parallel_mode, self.return_layernorm_output,
self.return_layernorm_output, )
)
self.post_forward()
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
...@@ -1274,144 +1264,140 @@ class _Linear(torch.autograd.Function): ...@@ -1274,144 +1264,140 @@ class _Linear(torch.autograd.Function):
def backward( def backward(
ctx, grad_output: torch.Tensor ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta) with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_Linear"):
(
inputmat,
inputmat_t,
weight,
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
( (
inputmat, grad_output,
inputmat_t, grad_output_c,
weight, grad_output_t,
weight_t_fp8, grad_bias,
fwd_scale_inverses, ) = TransformerEngineBaseModule.grad_output_preprocess(
) = ctx.saved_tensors ctx, grad_output, ctx.parallel_mode == "row"
)
( # Column Parallel Linear
grad_output, # Overlap input AG with dgrad
grad_output_c, if ctx.parallel_mode == "column" and ctx.sequence_parallel:
grad_output_t, if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
grad_bias, inputmat_t_total, handle = gather_along_last_dim(
) = TransformerEngineBaseModule.grad_output_preprocess( inputmat_t, ctx.tp_group, async_op=True
ctx, grad_output, ctx.parallel_mode == "row" )
) else:
inputmat_total, handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=True
)
else:
inputmat_t_total = inputmat_t
inputmat_total = inputmat
# Column Parallel Linear if ctx.is_first_microbatch is not None:
# Overlap input AG with dgrad accumulate_wgrad_into_param_main_grad = (
if ctx.parallel_mode == "column" and ctx.sequence_parallel: ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat_t_total, handle = gather_along_last_dim(
inputmat_t, ctx.tp_group, async_op=True
) )
else: else:
inputmat_total, handle = gather_along_first_dim( accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
inputmat, ctx.tp_group, async_op=True
)
else:
inputmat_t_total = inputmat_t
inputmat_total = inputmat
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.fp8: if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype( fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True ctx.fp8_meta["recipe"], fprop_tensor=True
) )
fp8_dtype_backward = get_fp8_te_dtype( fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False ctx.fp8_meta["recipe"], fprop_tensor=False
) )
# DGRAD
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
# Overlap dgrad-RS/AR with wgrad # DGRAD
if ctx.parallel_mode == "column" and ctx.sequence_parallel: dgrad = fp8_gemm(
handle.wait() weight_t_fp8,
dgrad, handle = reduce_scatter_along_first_dim( fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
dgrad, ctx.tp_group, async_op=True fp8_dtype_forward,
) grad_output_c,
elif ctx.parallel_mode == "column" and ctx.tensor_parallel: ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
if weight.requires_grad: # Overlap dgrad-RS/AR with wgrad
if ctx.fp8: if ctx.parallel_mode == "column" and ctx.sequence_parallel:
# WGRAD handle.wait()
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: dgrad, handle = reduce_scatter_along_first_dim(
wgrad = fp8_gemm( dgrad, ctx.tp_group, async_op=True
inputmat_t_total, )
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
fp8_dtype_forward, dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[ if weight.requires_grad:
tex.FP8BwdTensors.GRAD_OUTPUT1 if ctx.fp8:
], # WGRAD
fp8_dtype_backward, if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ctx.activation_dtype, wgrad = fp8_gemm(
get_workspace(), inputmat_t_total,
accumulate=accumulate_wgrad_into_param_main_grad, fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp32_output=ctx.fuse_wgrad_accumulation, fp8_dtype_forward,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, grad_output_t,
use_split_accumulator=_2X_ACC_WGRAD, ctx.fp8_meta["scaling_bwd"].scale_inv[
) tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
wgrad, _, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else: else:
wgrad, _, _ = gemm( # WGRAD
wgrad, grad_bias, _ = gemm(
inputmat_total, inputmat_total,
grad_output, grad_output,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
layout="NT", layout="NT",
grad=True, grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation, fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
) )
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
if not ctx.use_bias: # Column Parallel Linear
grad_bias = None if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
TransformerEngineBaseModule.post_backward( if not ctx.use_bias:
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group grad_bias = None
)
return ( return (
wgrad if weight.requires_grad else None, wgrad if weight.requires_grad else None,
...@@ -1617,29 +1603,26 @@ class Linear(TransformerEngineBaseModule): ...@@ -1617,29 +1603,26 @@ class Linear(TransformerEngineBaseModule):
produced) produced)
""" """
inp = self.pre_forward(inp) with self.prepare_forward(inp) as inp:
bias_tensor = bias if bias is not None else self.bias
bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply(
out = _Linear.apply( weight if weight is not None else self.weight,
weight if weight is not None else self.weight, self.weight1_fp8 if self.fp8 else None,
self.weight1_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None, inp,
inp, bias_tensor,
bias_tensor, self.use_bias,
self.use_bias, is_first_microbatch,
is_first_microbatch, self.fp8,
self.fp8, self.fp8_meta,
self.fp8_meta, self.fuse_wgrad_accumulation,
self.fuse_wgrad_accumulation, self.tp_group,
self.tp_group, self.sequence_parallel,
self.sequence_parallel, self.tp_size > 1,
self.tp_size > 1, self.activation_dtype,
self.activation_dtype, self.parallel_mode,
self.parallel_mode, )
)
self.post_forward()
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype) out = out + cast_if_needed(bias_tensor, self.activation_dtype)
...@@ -1871,111 +1854,165 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1871,111 +1854,165 @@ class _LayerNormMLP(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta) with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
name="_LayerNormMLP"):
(
inputmat,
ln_weight,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight,
fc1_weight_t_fp8,
fc2_weight,
fc2_weight_t_fp8,
fc1_bias,
fwd_scale_inverses,
) = ctx.saved_tensors
( (
inputmat, grad_output,
ln_weight, grad_output_c,
mu, grad_output_t,
rsigma, fc2_bias_grad,
ln_out, ) = TransformerEngineBaseModule.grad_output_preprocess(
fc1_out, ctx, grad_outputs[0], True
gelu_out, )
fc1_weight,
fc1_weight_t_fp8,
fc2_weight,
fc2_weight_t_fp8,
fc1_bias,
fwd_scale_inverses,
) = ctx.saved_tensors
( # Column Parallel Linear
grad_output, # Overlap input AG with dgrad
grad_output_c, if ctx.set_parallel_mode and ctx.sequence_parallel:
grad_output_t, ln_out_total, handle = gather_along_first_dim(
fc2_bias_grad, ln_out, ctx.tp_group, async_op=True
) = TransformerEngineBaseModule.grad_output_preprocess( )
ctx, grad_outputs[0], True else:
) ln_out_total = ln_out
# Column Parallel Linear if ctx.is_first_microbatch is not None:
# Overlap input AG with dgrad accumulate_wgrad_into_param_main_grad = (
if ctx.set_parallel_mode and ctx.sequence_parallel: ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
ln_out_total, handle = gather_along_first_dim( )
ln_out, ctx.tp_group, async_op=True else:
) accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
else:
ln_out_total = ln_out
if ctx.is_first_microbatch is not None: if ctx.fp8:
accumulate_wgrad_into_param_main_grad = ( fp8_dtype_forward = get_fp8_te_dtype(
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ctx.fp8_meta["recipe"], fprop_tensor=True
) )
else: fp8_dtype_backward = get_fp8_te_dtype(
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation ctx.fp8_meta["recipe"], fprop_tensor=False
)
if ctx.fp8: # FC2 DGRAD; Unconditional
fp8_dtype_forward = get_fp8_te_dtype( fc2_dgrad = fp8_gemm(
ctx.fp8_meta["recipe"], fprop_tensor=True fc2_weight_t_fp8,
) fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT],
fp8_dtype_backward = get_fp8_te_dtype( fp8_dtype_forward,
ctx.fp8_meta["recipe"], fprop_tensor=False grad_output_c,
) ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
# FC2 DGRAD; Unconditional # FC2 WGRAD
fc2_dgrad = fp8_gemm( if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
fc2_weight_t_fp8, if fc2_weight.requires_grad:
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT], gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fp8_dtype_forward, fc2_wgrad = fp8_gemm(
grad_output_c, gelu_out_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1], fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT],
fp8_dtype_backward, fp8_dtype_forward,
ctx.activation_dtype, grad_output_t,
get_workspace(), ctx.fp8_meta["scaling_bwd"].scale_inv[
use_split_accumulator=_2X_ACC_DGRAD, tex.FP8BwdTensors.GRAD_OUTPUT1
) ],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
fc2_dgrad,
fc1_out,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
else:
if fc2_weight.requires_grad:
gelu_out_c = cast_from_fp8(
gelu_out,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc2_wgrad, _, _ = gemm(
gelu_out_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
)
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
fc2_dgrad, fc1_out, fc1_bias
)
# FC2 WGRAD dgelu = cast_to_fp8(
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: dgelu_no_fp8,
if fc2_weight.requires_grad: ctx.fp8_meta["scaling_bwd"],
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) tex.FP8BwdTensors.GRAD_OUTPUT2,
fc2_wgrad = fp8_gemm(
gelu_out_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
) )
dgelu_t = None
fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused( # FC1 DGRAD: Unconditional
fc2_dgrad, fc1_dgrad = fp8_gemm(
fc1_out, fc1_weight_t_fp8,
ctx.fp8_meta["scaling_bwd"], fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_forward,
dgelu,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT2],
fp8_dtype_backward, fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
) )
else: else:
# FC2 DGRAD; Unconditional
fc2_dgrad, _, _ = gemm(
fc2_weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
gelu=not ctx.bias_gelu_nvfusion,
grad=True,
gelu_input=fc1_out,
)
# FC2 WGRAD
if fc2_weight.requires_grad: if fc2_weight.requires_grad:
gelu_out_c = cast_from_fp8( fc2_wgrad, fc2_bias_grad, _ = gemm(
gelu_out, gelu_out,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc2_wgrad, _, _ = gemm(
gelu_out_c,
grad_output, grad_output,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
...@@ -1984,172 +2021,114 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1984,172 +2021,114 @@ class _LayerNormMLP(torch.autograd.Function):
use_bias=ctx.use_bias, use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation, fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
if ctx.fuse_wgrad_accumulation
else None,
) )
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( if ctx.bias_gelu_nvfusion:
fc2_dgrad, fc1_out, fc1_bias fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
) else:
dgelu = fc2_dgrad
dgelu = cast_to_fp8(
dgelu_no_fp8,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
dgelu_t = None
# FC1 DGRAD: Unconditional
fc1_dgrad = fp8_gemm(
fc1_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
dgelu,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT2],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# FC2 DGRAD; Unconditional
fc2_dgrad, _, _ = gemm(
fc2_weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
gelu=not ctx.bias_gelu_nvfusion,
grad=True,
gelu_input=fc1_out,
)
# FC2 WGRAD # FC1 DGRAD: Unconditional
if fc2_weight.requires_grad: fc1_dgrad, _, _ = gemm(
fc2_wgrad, fc2_bias_grad, _ = gemm( fc1_weight,
gelu_out, dgelu,
grad_output,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
layout="NT", layout="NN",
grad=True, grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
) )
if ctx.bias_gelu_nvfusion: # Overlap dgrad-RS/AR with wgrad
fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) if ctx.set_parallel_mode and ctx.sequence_parallel:
else: handle.wait()
dgelu = fc2_dgrad fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True
# FC1 DGRAD: Unconditional )
fc1_dgrad, _, _ = gemm( elif ctx.set_parallel_mode and ctx.tensor_parallel:
fc1_weight, fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
dgelu,
ctx.activation_dtype, if fc1_weight.requires_grad:
get_workspace(), if ctx.fp8:
layout="NN", # FC1 WGRAD
grad=True, if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
) ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad = fp8_gemm(
# Overlap dgrad-RS/AR with wgrad ln_out_total_t,
if ctx.set_parallel_mode and ctx.sequence_parallel: fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
handle.wait() fp8_dtype_forward,
fc1_dgrad, handle = reduce_scatter_along_first_dim( dgelu_t,
fc1_dgrad, ctx.tp_group, async_op=True ctx.fp8_meta["scaling_bwd"].scale_inv[
) tex.FP8BwdTensors.GRAD_OUTPUT2
elif ctx.set_parallel_mode and ctx.tensor_parallel: ],
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) fp8_dtype_backward,
ctx.activation_dtype,
if fc1_weight.requires_grad: get_workspace(),
if ctx.fp8: accumulate=accumulate_wgrad_into_param_main_grad,
# FC1 WGRAD fp32_output=ctx.fuse_wgrad_accumulation,
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: out=fc1_weight.main_grad
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) if ctx.fuse_wgrad_accumulation
fc1_wgrad = fp8_gemm( else None,
ln_out_total_t, use_split_accumulator=_2X_ACC_WGRAD,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT], )
fp8_dtype_forward, else:
dgelu_t, ln_out_total_c = cast_from_fp8(
ctx.fp8_meta["scaling_bwd"].scale_inv[ ln_out_total,
tex.FP8BwdTensors.GRAD_OUTPUT2 ctx.fp8_meta["scaling_fwd"],
], tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_backward, fp8_dtype_forward,
ctx.activation_dtype, TE_DType[ctx.activation_dtype],
get_workspace(), )
accumulate=accumulate_wgrad_into_param_main_grad, fc1_wgrad, _, _ = gemm(
fp32_output=ctx.fuse_wgrad_accumulation, ln_out_total_c,
out=fc1_weight.main_grad dgelu_no_fp8,
if ctx.fuse_wgrad_accumulation ctx.activation_dtype,
else None, get_workspace(),
use_split_accumulator=_2X_ACC_WGRAD, layout="NT",
) grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
)
else: else:
ln_out_total_c = cast_from_fp8( # FC1 WGRAD
fc1_wgrad_outputs = gemm(
ln_out_total, ln_out_total,
ctx.fp8_meta["scaling_fwd"], dgelu,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc1_wgrad, _, _ = gemm(
ln_out_total_c,
dgelu_no_fp8,
ctx.activation_dtype, ctx.activation_dtype,
get_workspace(), get_workspace(),
layout="NT", layout="NT",
grad=True, grad=True,
use_bias=not ctx.bias_gelu_nvfusion,
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation, fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
if ctx.fuse_wgrad_accumulation
else None,
) )
else:
# FC1 WGRAD
fc1_wgrad_outputs = gemm(
ln_out_total,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=not ctx.bias_gelu_nvfusion,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if ctx.bias_gelu_nvfusion: if ctx.bias_gelu_nvfusion:
fc1_wgrad, _, _ = fc1_wgrad_outputs fc1_wgrad, _, _ = fc1_wgrad_outputs
else: else:
fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
# Column Parallel Linear # Column Parallel Linear
if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None: if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
handle.wait() handle.wait()
# LayerNorm gradient
d_ln_out = fc1_dgrad.view(inputmat.shape)
# Residual gradient # LayerNorm gradient
if ctx.return_layernorm_output: d_ln_out = fc1_dgrad.view(inputmat.shape)
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd( # Residual gradient
d_ln_out, inputmat, mu, rsigma, ln_weight if ctx.return_layernorm_output:
) d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
if not ctx.use_bias: dxmat, dgamma, dbeta = tex.layernorm_bwd(
fc2_bias_grad = None d_ln_out, inputmat, mu, rsigma, ln_weight
)
TransformerEngineBaseModule.post_backward( if not ctx.use_bias:
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group fc2_bias_grad = None
)
return ( return (
dxmat.view(ctx.inp_shape), dxmat.view(ctx.inp_shape),
...@@ -2420,36 +2399,33 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2420,36 +2399,33 @@ class LayerNormMLP(TransformerEngineBaseModule):
produced) produced)
""" """
inp = self.pre_forward(inp, num_gemms=2) with self.prepare_forward(inp, num_gemms=2) as inp:
out = _LayerNormMLP.apply(
out = _LayerNormMLP.apply( inp,
inp, self.layer_norm_weight,
self.layer_norm_weight, self.layer_norm_bias,
self.layer_norm_bias, self.fc1_weight,
self.fc1_weight, self.weight1_fp8 if self.fp8 else None,
self.weight1_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None, self.fc1_bias,
self.fc1_bias, self.fc2_weight,
self.fc2_weight, self.weight2_fp8 if self.fp8 else None,
self.weight2_fp8 if self.fp8 else None, self.weight2_t_fp8 if self.fp8 else None,
self.weight2_t_fp8 if self.fp8 else None, self.fc2_bias,
self.fc2_bias, self.use_bias,
self.use_bias, self.eps,
self.eps, is_first_microbatch,
is_first_microbatch, self.fp8,
self.fp8, self.fp8_meta,
self.fp8_meta, self.fuse_wgrad_accumulation,
self.fuse_wgrad_accumulation, self.tp_group,
self.tp_group, self.sequence_parallel,
self.sequence_parallel, self.tp_size > 1,
self.tp_size > 1, self.activation_dtype,
self.activation_dtype, self.return_layernorm_output,
self.return_layernorm_output, self.bias_gelu_nvfusion,
self.bias_gelu_nvfusion, self.set_parallel_mode,
self.set_parallel_mode, )
)
self.post_forward()
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment