Unverified commit aadd3e7c authored by Przemyslaw Tredak, committed by GitHub

Add NVTX to TE modules (#50)

* Add NVTX to TE modules
* Fix
* Fix pylint
* Fix NVTX in _prepare_backward
* Add NVTX to C API
* Fix cpplint and link nvToolsExt
* Add NVTX to GeGlu

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent eed1fa26
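For context: NVTX (the NVIDIA Tools Extension) exposes a push/pop API for named ranges, and profilers such as Nsight Systems render each push/pop pair as a labeled interval on the calling thread's timeline. A minimal sketch of the raw API this commit builds on (standalone; assumes the nvToolsExt header and library are available, as wired up in the CMake change below):

#include <nvToolsExt.h>

void step() {
  nvtxRangePush("step");  // open a named range on this thread
  // ... work attributed to "step" in the profiler timeline ...
  nvtxRangePop();         // close the innermost open range
}

The commit wraps this pair in an RAII helper (see nvtx.h below) so that every nvte_* C API entry point opens a range named after itself and closes it on return.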
@@ -22,7 +22,8 @@ disable=too-many-locals,
         attribute-defined-outside-init,
         global-statement,
         too-many-branches,
-        global-variable-not-assigned
+        global-variable-not-assigned,
+        redefined-argument-from-local
 
 [TYPECHECK]
 ignored-modules=torch
...
@@ -40,9 +40,9 @@ add_library(transformer_engine SHARED
 target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")
 
-find_package(CUDAToolkit REQUIRED cublas)
-list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart)
+find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
+list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
 
 target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
 target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
...
@@ -116,6 +116,7 @@ void dgeglu(const Tensor &grad,
 void nvte_gelu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_gelu);
   using namespace transformer_engine;
   gelu_cast(*reinterpret_cast<const Tensor*>(input),
             reinterpret_cast<Tensor*>(output),
@@ -125,6 +126,7 @@ void nvte_gelu(const NVTETensor input,
 void nvte_geglu(const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
+  NVTE_API_CALL(nvte_geglu);
   using namespace transformer_engine;
   geglu_cast(*reinterpret_cast<const Tensor*>(input),
              reinterpret_cast<Tensor*>(output),
@@ -135,6 +137,7 @@ void nvte_dgeglu(const NVTETensor grad,
                  const NVTETensor input,
                  NVTETensor output,
                  cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dgeglu);
   using namespace transformer_engine;
   dgeglu(*reinterpret_cast<const Tensor*>(grad),
          *reinterpret_cast<const Tensor*>(input),
...
@@ -20,6 +20,7 @@
 #include <string>
 #include <tuple>
 #include <vector>
+#include "nvtx.h"
 
 namespace transformer_engine {
@@ -285,6 +286,9 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
 bool is_fp8_dtype(const DType t);
 
+#define NVTE_API_CALL(api_name) \
+  transformer_engine::nvtx::NVTXWrapper _ ## api_name ## _nvtx_wrapper(#api_name);
+
 }  // namespace transformer_engine
 
 #endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_
...
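The NVTE_API_CALL macro declares a function-local NVTXWrapper whose constructor pushes a range named after the entry point and whose destructor pops it when the function exits. A hypothetical expansion (the function name nvte_example and the include path are invented for illustration):

#include <cuda_runtime.h>
#include "common.h"  // defines NVTE_API_CALL, as in the hunk above

void nvte_example(cudaStream_t stream) {  // hypothetical API entry point
  NVTE_API_CALL(nvte_example);
  // The line above expands to:
  //   transformer_engine::nvtx::NVTXWrapper _nvte_example_nvtx_wrapper("nvte_example");
  // The range opens here and is popped automatically on every return path.
  (void)stream;
}

The token pasting (_ ## api_name ## _nvtx_wrapper) gives the local variable a name unlikely to collide with function parameters, and the stringizing (#api_name) supplies the range label.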
@@ -1060,12 +1060,13 @@ void nvte_scaled_softmax_forward(
     float scale_factor,
     cudaStream_t stream
 ) {
+  NVTE_API_CALL(nvte_scaled_softmax_forward);
   using namespace transformer_engine;
   scaled_softmax_forward(
       *reinterpret_cast<const Tensor*>(input),
       reinterpret_cast<Tensor*>(softmax_results),
       scale_factor,
       stream);
 }
@@ -1076,13 +1077,14 @@ void nvte_scaled_softmax_backward(
     float scale_factor,
     cudaStream_t stream
 ) {
+  NVTE_API_CALL(nvte_scaled_softmax_backward);
   using namespace transformer_engine;
   scaled_softmax_backward(
       *reinterpret_cast<Tensor*>(output_grads),
       *reinterpret_cast<const Tensor*>(incoming_grads),
       *reinterpret_cast<const Tensor*>(softmax_results),
       scale_factor,
       stream);
 }
@@ -1093,13 +1095,14 @@ void nvte_scaled_masked_softmax_forward(
     float scale_factor,
     cudaStream_t stream
 ) {
+  NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
   using namespace transformer_engine;
   scaled_masked_softmax_forward(
       *reinterpret_cast<const Tensor*>(input),
       *reinterpret_cast<const Tensor*>(mask),
       reinterpret_cast<Tensor*>(softmax_results),
       scale_factor,
       stream);
 }
@@ -1110,11 +1113,12 @@ void nvte_scaled_masked_softmax_backward(
     float scale_factor,
     cudaStream_t stream
 ) {
+  NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
   using namespace transformer_engine;
   scaled_masked_softmax_backward(
       *reinterpret_cast<Tensor*>(output_grads),
       *reinterpret_cast<const Tensor*>(incoming_grads),
       *reinterpret_cast<const Tensor*>(softmax_results),
       scale_factor,
       stream);
 }
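Note that these entry points only enqueue kernels on the given stream, so each NVTX range measures host-side launch and dispatch overhead rather than GPU execution time; profilers correlate the two on separate timeline rows. A sketch of that asynchrony (toy kernel and names invented for illustration):

#include <cuda_runtime.h>
#include <nvToolsExt.h>

__global__ void toy_kernel() {}

void launch_on(cudaStream_t stream) {
  nvtxRangePush("launch_on");         // covers only the host-side enqueue
  toy_kernel<<<1, 1, 0, stream>>>();  // asynchronous: returns immediately
  nvtxRangePop();                     // usually closes before the kernel finishes
}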
@@ -236,6 +236,7 @@ void nvte_cublas_gemm(const NVTETensor A,
                       bool accumulate,
                       bool use_split_accumulator,
                       cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cublas_gemm);
   using namespace transformer_engine;
   const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
   const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
...
@@ -375,6 +375,7 @@ void nvte_layernorm_fwd(const NVTETensor x,  // BxSxhidden_size
                         const int multiprocessorCount,
                         NVTETensor workspace,
                         NVTETensor barrier) {
+  NVTE_API_CALL(nvte_layernorm_fwd);
   using namespace transformer_engine;
   layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
                 *reinterpret_cast<const Tensor*>(gamma),
@@ -403,6 +404,7 @@ void nvte_layernorm_bwd(const NVTETensor dz,  // BxSxhidden_size
                         const int multiprocessorCount,
                         NVTETensor workspace,
                         NVTETensor barrier) {
+  NVTE_API_CALL(nvte_layernorm_bwd);
   using namespace transformer_engine;
   layernorm_bwd(*reinterpret_cast<const Tensor*>(dz),
                 *reinterpret_cast<const Tensor*>(x),
...
/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_
#define TRANSFORMER_ENGINE_COMMON_NVTX_H_

#include <string>

#include <nvToolsExt.h>

namespace transformer_engine::nvtx {

struct NVTXWrapper {
  explicit NVTXWrapper(const std::string &name) {
    nvtxRangePush(name.c_str());
  }

  ~NVTXWrapper() {
    nvtxRangePop();
  }
};

}  // namespace transformer_engine::nvtx

#endif  // TRANSFORMER_ENGINE_COMMON_NVTX_H_
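Because the wrapper pops in its destructor, every pushed range is closed even on early returns or exceptions, which hand-written nvtxRangePush/nvtxRangePop pairs can miss. A small usage sketch (the include path and the function are assumptions for illustration):

#include "nvtx.h"  // the header above; include path is an assumption

void process(bool fast_path) {
  transformer_engine::nvtx::NVTXWrapper range("process");  // push "process"
  if (fast_path) return;  // destructor pops the range on this early return
  // ... slow path ...
}                         // ... and on the normal return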
@@ -287,11 +287,12 @@ void nvte_rmsnorm_fwd(const NVTETensor x,      // Nxhidden_size
                       const NVTETensor gamma,  // hidden_size
                       const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream,
                       const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
+  NVTE_API_CALL(nvte_rmsnorm_fwd);
   using namespace transformer_engine;
   rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
               epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), stream,
               multiprocessorCount, reinterpret_cast<Tensor *>(workspace),
               reinterpret_cast<Tensor *>(barrier));
 }
 
 void nvte_rmsnorm_bwd(const NVTETensor dz,     // Nxhidden_size
@@ -300,10 +301,11 @@ void nvte_rmsnorm_bwd(const NVTETensor dz,     // Nxhidden_size
                       const NVTETensor gamma,  // hidden_size
                       NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream,
                       const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) {
+  NVTE_API_CALL(nvte_rmsnorm_bwd);
   using namespace transformer_engine;
   rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
               *reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
               reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
               reinterpret_cast<Tensor *>(dgamma_part), stream, multiprocessorCount,
               reinterpret_cast<Tensor *>(workspace), reinterpret_cast<Tensor *>(barrier));
 }
@@ -392,6 +392,7 @@ void nvte_cast_transpose(const NVTETensor input,
                          NVTETensor cast_output,
                          NVTETensor transposed_output,
                          cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cast_transpose);
   using namespace transformer_engine;
   cast_transpose(*reinterpret_cast<const Tensor*>(input),
                  reinterpret_cast<Tensor*>(cast_output),
...
@@ -1566,6 +1566,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
                                NVTETensor dbias,
                                NVTETensor workspace,
                                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cast_transpose_dbias);
   using namespace transformer_engine;
   cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
                        reinterpret_cast<Tensor*>(cast_output),
@@ -1582,6 +1583,7 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
                                      NVTETensor dbias,
                                      NVTETensor workspace,
                                      cudaStream_t stream) {
+  NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
   using namespace transformer_engine;
   cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input),
                              *reinterpret_cast<const Tensor*>(gelu_input),
@@ -1597,6 +1599,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
                                 NVTETensor cast_output,
                                 NVTETensor transposed_output,
                                 cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dgeglu_cast_transpose);
   using namespace transformer_engine;
   dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input),
                         *reinterpret_cast<const Tensor*>(geglu_input),
...
@@ -339,6 +339,7 @@ void nvte_multi_cast_transpose(size_t num_tensors,
                                NVTETensor* cast_output_list,
                                NVTETensor* transposed_output_list,
                                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_multi_cast_transpose);
   using namespace transformer_engine;
   std::vector<Tensor*> input_list_,
                        cast_output_list_, transposed_output_list_;
...
@@ -308,6 +308,7 @@ void transpose(const Tensor &input,
 void nvte_transpose(const NVTETensor input,
                     NVTETensor transposed_output,
                     cudaStream_t stream) {
+  NVTE_API_CALL(nvte_transpose);
   using namespace transformer_engine;
   transpose(*reinterpret_cast<const Tensor*>(input),
             reinterpret_cast<Tensor*>(transposed_output),
...
@@ -95,6 +95,7 @@ void fp8_dequantize(const Tensor &input,
 void nvte_fp8_quantize(const NVTETensor input,
                        NVTETensor output,
                        cudaStream_t stream) {
+  NVTE_API_CALL(nvte_fp8_quantize);
   using namespace transformer_engine;
   fp8_quantize(*reinterpret_cast<const Tensor*>(input),
                reinterpret_cast<Tensor*>(output),
@@ -104,6 +105,7 @@ void nvte_fp8_quantize(const NVTETensor input,
 void nvte_fp8_dequantize(const NVTETensor input,
                          NVTETensor output,
                          cudaStream_t stream) {
+  NVTE_API_CALL(nvte_fp8_dequantize);
   using namespace transformer_engine;
   fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
                  reinterpret_cast<Tensor*>(output),
...
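Since each instrumented entry point holds its range for the duration of the call, nested nvte_* calls produce properly nested NVTX ranges. A self-contained toy model of that nesting (printf stand-ins replace the real nvtxRangePush/nvtxRangePop so the behavior is visible on stdout; all names here are invented):

#include <cstdio>
#include <string>

static int depth = 0;

// Toy stand-in for NVTXWrapper: prints instead of calling into nvToolsExt.
struct Range {
  explicit Range(const std::string &name) {
    std::printf("%*spush %s\n", 2 * depth++, "", name.c_str());
  }
  ~Range() { std::printf("%*spop\n", 2 * --depth, ""); }
};

#define API_CALL(name) Range _ ## name ## _range(#name);

void inner() { API_CALL(inner); /* kernel launch would go here */ }
void outer() { API_CALL(outer); inner(); }

int main() { outer(); }  // prints: push outer / push inner / pop / pop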