"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "f5412e5f5a804a14f80993310622b6088598412f"
Commit 89f94ba2, authored by Kirthi Shankar Sivamani, committed by GitHub

Softmax docstrings and type fixes (#37)

* Softmax docs and type fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint whitespace
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* change API, better naming, const fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0291a608
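The heart of the API change: the backward entry points now take the incoming gradients and the saved softmax results as read-only inputs and write to a separate output_grads tensor, instead of mutating a single const-qualified argument. A minimal caller-side sketch of the new signature; the handle names and the wrapper function are hypothetical, only the nvte_scaled_softmax_backward signature comes from this commit, and the include path assumes the usual installed layout:

#include <cuda_runtime.h>
#include <transformer_engine/softmax.h>

// Sketch only: `incoming`, `probs`, and `dgrad` are assumed to be valid,
// shape-compatible NVTETensor handles with a 16-bit dtype, built elsewhere.
void run_softmax_backward(NVTETensor incoming, NVTETensor probs,
                          NVTETensor dgrad, float scale,
                          cudaStream_t stream) {
  // New argument order after this commit:
  // (incoming_grads, softmax_results, output_grads, scale_factor, stream).
  nvte_scaled_softmax_backward(incoming, probs, dgrad, scale, stream);
}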
@@ -18,5 +18,6 @@ directly from C/C++, without Python.
    cast.h <cast>
    gemm.h <gemm>
    layer_norm.h <layer_norm>
+   softmax.h <softmax>
    transformer_engine.h <transformer_engine>
    transpose.h <transpose>
New file, the <softmax> page referenced above:

+..
+    Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+    See LICENSE for license information.
+
+softmax.h
+=========
+
+.. doxygenfile:: softmax.h
@@ -349,7 +349,7 @@ __global__ void scaled_masked_softmax_warp_forward(
 template <typename input_t, typename output_t, typename acc_t, int log2_elements>
 __global__ void scaled_masked_softmax_warp_backward(
     output_t *gradInput,
-    input_t *grad,
+    const input_t *grad,
     const input_t *output,
     acc_t scale,
     int micro_batch_size,
@@ -773,7 +773,7 @@ void dispatch_scaled_masked_softmax_forward(
 template<typename input_t, typename output_t, typename acc_t>
 void dispatch_scaled_masked_softmax_backward(
     output_t *grad_input,
-    input_t *grad,
+    const input_t *grad,
     const input_t *output,
     const acc_t scale,
     int query_seq_len,
@@ -968,7 +968,8 @@ void scaled_softmax_forward(
 }

 void scaled_softmax_backward(
-    const Tensor output_grads,
+    Tensor output_grads,
+    const Tensor incoming_grads,
     const Tensor softmax_results,
     float scale_factor,
     cudaStream_t stream) {
@@ -983,7 +984,7 @@ void scaled_softmax_backward(
   TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
     dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>(
       reinterpret_cast<softmax_type*>(output_grads.dptr),
-      reinterpret_cast<softmax_type*>(output_grads.dptr),
+      reinterpret_cast<softmax_type const*>(incoming_grads.dptr),
       reinterpret_cast<softmax_type const*>(softmax_results.dptr),
       scale_factor,
       query_seq_len,
@@ -1023,7 +1024,8 @@ void scaled_masked_softmax_forward(
 void scaled_masked_softmax_backward(
-    const Tensor output_grads,
+    Tensor output_grads,
+    const Tensor incoming_grads,
     const Tensor softmax_results,
     float scale_factor,
     cudaStream_t stream
@@ -1038,7 +1040,7 @@ void scaled_masked_softmax_backward(
   TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
     dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>(
       reinterpret_cast<softmax_type*>(output_grads.dptr),
-      reinterpret_cast<softmax_type*>(output_grads.dptr),
+      reinterpret_cast<softmax_type const*>(incoming_grads.dptr),
       reinterpret_cast<softmax_type const*>(softmax_results.dptr),
       scale_factor,
       query_seq_len,
@@ -1068,14 +1070,16 @@ void nvte_scaled_softmax_forward(
 void nvte_scaled_softmax_backward(
-    const NVTETensor output_grads,
+    const NVTETensor incoming_grads,
     const NVTETensor softmax_results,
+    NVTETensor output_grads,
     float scale_factor,
     cudaStream_t stream
 ) {
   using namespace transformer_engine;
   scaled_softmax_backward(
-      *reinterpret_cast<const Tensor*>(output_grads),
+      *reinterpret_cast<Tensor*>(output_grads),
+      *reinterpret_cast<const Tensor*>(incoming_grads),
       *reinterpret_cast<const Tensor*>(softmax_results),
       scale_factor,
       stream);
@@ -1100,15 +1104,17 @@ void nvte_scaled_masked_softmax_forward(
 void nvte_scaled_masked_softmax_backward(
-    const NVTETensor input,
-    NVTETensor softmax_results,
+    const NVTETensor incoming_grads,
+    const NVTETensor softmax_results,
+    NVTETensor output_grads,
     float scale_factor,
     cudaStream_t stream
 ) {
   using namespace transformer_engine;
   scaled_masked_softmax_backward(
-      *reinterpret_cast<const Tensor*>(input),
-      *reinterpret_cast<Tensor*>(softmax_results),
+      *reinterpret_cast<Tensor*>(output_grads),
+      *reinterpret_cast<const Tensor*>(incoming_grads),
+      *reinterpret_cast<const Tensor*>(softmax_results),
       scale_factor,
       stream);
 }
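All of the nvte_* wrappers follow the pattern visible in this hunk: the opaque NVTETensor handle is reinterpreted as the library-internal transformer_engine::Tensor, and the const-ness of each cast mirrors the const-ness of the parameter. A toy, self-contained sketch of that pattern with stand-in types (Handle and Impl are illustrative, not the real TE definitions):

#include <cstdio>

// Stand-ins: the real NVTETensor is an opaque handle; Impl plays the role
// of the internal tensor type the library casts it to.
typedef void *Handle;
struct Impl { float *dptr; int n; };

// The internal layer takes typed, const-correct references.
static void backward_impl(Impl &out, const Impl &in) {
  for (int i = 0; i < in.n; ++i) out.dptr[i] = in.dptr[i];  // placeholder math
}

// The C-style entry point: a read-only handle becomes a const reference,
// a writable handle stays mutable, the same shape as the wrappers above.
void c_entry_point(const Handle incoming, Handle output) {
  backward_impl(*reinterpret_cast<Impl *>(output),
                *reinterpret_cast<const Impl *>(incoming));
}

int main() {
  float in[2] = {1.f, 2.f}, out[2] = {0.f, 0.f};
  Impl a{in, 2}, b{out, 2};
  c_entry_point(&a, &b);
  std::printf("%g %g\n", out[0], out[1]);  // prints: 1 2
}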
@@ -248,7 +248,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
 template <typename input_t, typename output_t, typename acc_t, int log2_elements>
 __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     output_t *gradInput,
-    input_t *grad,
+    const input_t *grad,
     const input_t *output,
     acc_t scale,
     int micro_batch_size,
@@ -509,7 +509,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
 template<typename input_t, typename output_t, typename acc_t>
 void dispatch_scaled_upper_triang_masked_softmax_backward(
     output_t *grad_input,
-    input_t *grad,
+    const input_t *grad,
     const input_t *output,
     const acc_t scale,
     int softmax_elements,
@@ -683,7 +683,8 @@ void scaled_upper_triang_masked_softmax_forward(
 void scaled_upper_triang_masked_softmax_backward(
-    const Tensor output_grads,
+    Tensor output_grads,
+    const Tensor incoming_grads,
     const Tensor softmax_results,
     float scale_factor,
     cudaStream_t stream) {
@@ -695,7 +696,7 @@ void scaled_upper_triang_masked_softmax_backward(
   TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
     dispatch_scaled_upper_triang_masked_softmax_backward<softmax_type, softmax_type, float>(
       reinterpret_cast<softmax_type*>(output_grads.dptr),
-      reinterpret_cast<softmax_type*>(output_grads.dptr),
+      reinterpret_cast<softmax_type const*>(incoming_grads.dptr),
       reinterpret_cast<softmax_type const*>(softmax_results.dptr),
       scale_factor,
       seq_len,
@@ -723,14 +724,16 @@ void nvte_scaled_upper_triang_masked_softmax_forward(
 void nvte_scaled_upper_triang_masked_softmax_backward(
-    const NVTETensor output_grads,
+    const NVTETensor incoming_grads,
     const NVTETensor softmax_results,
+    NVTETensor output_grads,
     float scale_factor,
     cudaStream_t stream
 ) {
   using namespace transformer_engine;
   scaled_upper_triang_masked_softmax_backward(
-      *reinterpret_cast<const Tensor*>(output_grads),
+      *reinterpret_cast<Tensor*>(output_grads),
+      *reinterpret_cast<const Tensor*>(incoming_grads),
       *reinterpret_cast<const Tensor*>(softmax_results),
       scale_factor,
       stream);
@@ -15,7 +15,13 @@
 extern "C" {
 #endif

+/*! \brief Compute scaled softmax activation on the input.
+ *
+ *  \param[in]     input              Input tensor for softmax.
+ *  \param[out]    softmax_results    Output tensor.
+ *  \param[in]     scale_factor       Scalar for the input tensor.
+ *  \param[in]     stream             CUDA stream used for the operation.
+ */
 void nvte_scaled_softmax_forward(
     const NVTETensor input,
     NVTETensor softmax_results,
@@ -24,14 +30,35 @@ void nvte_scaled_softmax_forward(
 );

+/*! \brief Compute the backward of the scaled softmax activation.
+ *
+ *  - `incoming_grads` is the input tensor containing the gradients received from the following layer.
+ *  - `softmax_results` is the output tensor of the corresponding forward softmax operation.
+ *  - `output_grads` is the output tensor containing the computed gradients.
+ *
+ *  \param[in]     incoming_grads     Input gradient tensor for backward.
+ *  \param[in]     softmax_results    Output tensor of softmax forward.
+ *  \param[out]    output_grads       Output tensor.
+ *  \param[in]     scale_factor       Scalar for the output tensor.
+ *  \param[in]     stream             CUDA stream used for the operation.
+ */
 void nvte_scaled_softmax_backward(
-    const NVTETensor output_grads,
+    const NVTETensor incoming_grads,
     const NVTETensor softmax_results,
+    NVTETensor output_grads,
     float scale_factor,
     cudaStream_t stream
 );

+/*! \brief Compute scaled masked softmax activation on the input.
+ *
+ *  \param[in]     input              Input tensor for softmax.
+ *  \param[in]     mask               Mask for the input tensor.
+ *  \param[out]    softmax_results    Output tensor.
+ *  \param[in]     scale_factor       Scalar for the input tensor.
+ *  \param[in]     stream             CUDA stream used for the operation.
+ */
 void nvte_scaled_masked_softmax_forward(
     const NVTETensor input,
     const NVTETensor mask,
@@ -41,14 +68,34 @@ void nvte_scaled_masked_softmax_forward(
 );

+/*! \brief Compute the backward of the scaled masked softmax activation.
+ *
+ *  - `incoming_grads` is the input tensor containing the gradients received from the following layer.
+ *  - `softmax_results` is the output tensor of the corresponding forward softmax operation.
+ *  - `output_grads` is the output tensor containing the computed gradients.
+ *
+ *  \param[in]     incoming_grads     Input gradient tensor for backward.
+ *  \param[in]     softmax_results    Output tensor of softmax forward.
+ *  \param[out]    output_grads       Output tensor.
+ *  \param[in]     scale_factor       Scalar for the output tensor.
+ *  \param[in]     stream             CUDA stream used for the operation.
+ */
 void nvte_scaled_masked_softmax_backward(
-    const NVTETensor input,
-    NVTETensor softmax_results,
+    const NVTETensor incoming_grads,
+    const NVTETensor softmax_results,
+    NVTETensor output_grads,
     float scale_factor,
     cudaStream_t stream
 );

+/*! \brief Compute scaled softmax activation using a 2D upper triangular mask on the input.
+ *
+ *  \param[in]     input              Input tensor for softmax.
+ *  \param[out]    softmax_results    Output tensor.
+ *  \param[in]     scale_factor       Scalar for the input tensor.
+ *  \param[in]     stream             CUDA stream used for the operation.
+ */
 void nvte_scaled_upper_triang_masked_softmax_forward(
     const NVTETensor input,
     NVTETensor softmax_results,
@@ -57,9 +104,22 @@ void nvte_scaled_upper_triang_masked_softmax_forward(
 );

+/*! \brief Compute the backward of the scaled softmax activation using a 2D upper triangular mask.
+ *
+ *  - `incoming_grads` is the input tensor containing the gradients received from the following layer.
+ *  - `softmax_results` is the output tensor of the corresponding forward softmax operation.
+ *  - `output_grads` is the output tensor containing the computed gradients.
+ *
+ *  \param[in]     incoming_grads     Input gradient tensor for backward.
+ *  \param[in]     softmax_results    Output tensor of softmax forward.
+ *  \param[out]    output_grads       Output tensor.
+ *  \param[in]     scale_factor       Scalar for the output tensor.
+ *  \param[in]     stream             CUDA stream used for the operation.
+ */
 void nvte_scaled_upper_triang_masked_softmax_backward(
-    const NVTETensor output_grads,
+    const NVTETensor incoming_grads,
     const NVTETensor softmax_results,
+    NVTETensor output_grads,
     float scale_factor,
     cudaStream_t stream
 );
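Taken together, the documented signatures pair one forward call with one backward call that consumes the saved forward output. A hedged usage sketch for the masked variant; every handle here is hypothetical and must be created through the NVTETensor machinery beforehand, and the include path assumes the installed layout:

#include <cuda_runtime.h>
#include <transformer_engine/softmax.h>

// Sketch only: x (logits), mask, probs, dy, dx are pre-built NVTETensor
// handles of compatible shapes; in attention, scale is often 1/sqrt(d_head).
void masked_softmax_step(NVTETensor x, NVTETensor mask, NVTETensor probs,
                         NVTETensor dy, NVTETensor dx, float scale,
                         cudaStream_t stream) {
  // Forward: scale and mask the logits, then softmax; keep probs for backward.
  nvte_scaled_masked_softmax_forward(x, mask, probs, scale, stream);

  // ... rest of the forward pass; upstream gradients dy arrive later ...

  // Backward: reads dy and the saved probs, writes dx. As the PyTorch
  // bindings below show, dy and dx may alias for in-place operation.
  nvte_scaled_masked_softmax_backward(dy, probs, dx, scale, stream);
}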
@@ -537,8 +537,9 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
   auto output_grads_cu = makeTransformerEngineTensor(output_grads);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

+  // Produce gradients in place.
   nvte_scaled_softmax_backward(
-      output_grads_cu.data(), softmax_results_cu.data(),
+      output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
       scale_factor, at::cuda::getCurrentCUDAStream());

   return output_grads;
@@ -608,8 +609,9 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
   auto output_grads_cu = makeTransformerEngineTensor(output_grads);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

+  // Produce gradients in place.
   nvte_scaled_softmax_backward(
-      output_grads_cu.data(), softmax_results_cu.data(),
+      output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
       scale_factor, at::cuda::getCurrentCUDAStream());

   return output_grads;
@@ -671,8 +673,10 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
   auto output_grads_cu = makeTransformerEngineTensor(output_grads);
   auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

+  // Produce gradients in place.
   nvte_scaled_upper_triang_masked_softmax_backward(output_grads_cu.data(),
                                                    softmax_results_cu.data(),
+                                                   output_grads_cu.data(),
                                                    scale_factor,
                                                    at::cuda::getCurrentCUDAStream());
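The new `// Produce gradients in place.` comments flag why the extra argument matters here: each binding passes output_grads_cu.data() as both incoming_grads and output_grads, so the incoming-gradient buffer is reused for the result. The aliasing is safe because the kernels finish reading every gradient element (for the row reduction) before any element is overwritten. A plain C++ sketch of the row-wise math, illustrative rather than the actual CUDA kernel:

// dx_i = scale * y_i * (dy_i - sum_j dy_j * y_j), computed one row at a time.
// Calling this with grad_input == grad is fine: the reduction loop reads all
// of grad first, and each write uses only values already read.
void softmax_backward_row(float *grad_input, const float *grad,
                          const float *output, float scale, int n) {
  float dot = 0.f;
  for (int i = 0; i < n; ++i) dot += grad[i] * output[i];
  for (int i = 0; i < n; ++i)
    grad_input[i] = scale * output[i] * (grad[i] - dot);
}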