Unverified Commit 73939472 authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[Common] MXFP8 kernel for grouped tensors (#2586)



* Rebased to main
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed the year to 2026
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Added compilation guards
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added BWD pass
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added dbias and dact tests. Refactoring.
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added grouped MXFP8 DACT and ACT API and tests
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed a typo
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixes per the review
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* More fixes from the review
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Fixes per the review
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Relaxed requirement for last dim from mod128 to mod32
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Added alignment checks when tensor descriptors are modified
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarvthumbe1503 <vthumbe@nvidia.com>
parent 71971e33
......@@ -11,6 +11,7 @@ add_executable(test_operator
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_mxfp8.cu
test_cast_mxfp8_grouped.cu
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
......
This diff is collapsed.
......@@ -13,6 +13,14 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}
/*! GeLU forward over a grouped tensor; quantization follows the output's scaling mode. */
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_gelu);
  using namespace transformer_engine;
  // Forward activation path; no quantization config is supplied.
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/true, Empty, gelu<fp32, fp32>>(
      input, output, /*quant_config=*/nullptr, stream);
}
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
......@@ -20,6 +28,20 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}
/*! GeLU backward (dgrad) over grouped tensors; no dbias is produced on this path. */
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dgelu);
  using namespace transformer_engine;
  // DACT-only path: the dbias/workspace handles of the shared helper are unused.
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/false, /*IS_DACT=*/true, Empty,
                                      dgelu<fp32, fp32>>(grad, input, output, /*dbias=*/nullptr,
                                                         /*workspace=*/nullptr,
                                                         /*quant_config=*/nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
/*! Fused dGeLU + cast + column-reduction (dbias) over grouped tensors. */
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
  using namespace transformer_engine;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty,
                                      dgelu<fp32, fp32>>(input, activation_input, output, dbias,
                                                         workspace, /*quant_config=*/nullptr,
                                                         stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
......@@ -54,6 +90,15 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}
/*! Quick GeLU forward over a grouped tensor; quantization follows the output's scaling mode. */
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_qgelu);
  using namespace transformer_engine;
  // Forward activation path; no quantization config is supplied.
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/true, Empty, qgelu<fp32, fp32>>(
      input, output, /*quant_config=*/nullptr, stream);
}
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
......@@ -61,6 +106,20 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
/*! Quick GeLU backward (dgrad) over grouped tensors; no dbias is produced on this path. */
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dqgelu);
  using namespace transformer_engine;
  // DACT-only path: the dbias/workspace handles of the shared helper are unused.
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/false, /*IS_DACT=*/true, Empty,
                                      dqgelu<fp32, fp32>>(grad, input, output, /*dbias=*/nullptr,
                                                          /*workspace=*/nullptr,
                                                          /*quant_config=*/nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream);
}
/*! Fused dQuickGeLU + cast + column-reduction (dbias) over grouped tensors. */
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
                                      NVTEGroupedTensor output, NVTETensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
  using namespace transformer_engine;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty,
                                      dqgelu<fp32, fp32>>(input, activation_input, output, dbias,
                                                          workspace, /*quant_config=*/nullptr,
                                                          stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
......
......@@ -13,6 +13,14 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}
/*! ReLU forward over a grouped tensor; quantization follows the output's scaling mode. */
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_relu);
  using namespace transformer_engine;
  // Forward activation path; no quantization config is supplied.
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/true, Empty, relu<fp32, fp32>>(
      input, output, /*quant_config=*/nullptr, stream);
}
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
......@@ -20,6 +28,20 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
/*! ReLU backward (dgrad) over grouped tensors; no dbias is produced on this path. */
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_drelu);
  using namespace transformer_engine;
  // DACT-only path: the dbias/workspace handles of the shared helper are unused.
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/false, /*IS_DACT=*/true, Empty,
                                      drelu<fp32, fp32>>(grad, input, output, /*dbias=*/nullptr,
                                                         /*workspace=*/nullptr,
                                                         /*quant_config=*/nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
/*! Fused dReLU + cast + column-reduction (dbias) over grouped tensors. */
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
  using namespace transformer_engine;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty,
                                      drelu<fp32, fp32>>(input, activation_input, output, dbias,
                                                         workspace, /*quant_config=*/nullptr,
                                                         stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
......@@ -54,6 +90,15 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}
/*! Squared-ReLU forward over a grouped tensor; quantization follows the output's scaling mode. */
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_srelu);
  using namespace transformer_engine;
  // Forward activation path; no quantization config is supplied.
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/true, Empty, srelu<fp32, fp32>>(
      input, output, /*quant_config=*/nullptr, stream);
}
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
......@@ -61,6 +106,20 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
/*! Squared-ReLU backward (dgrad) over grouped tensors; no dbias is produced on this path. */
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dsrelu);
  using namespace transformer_engine;
  // DACT-only path: the dbias/workspace handles of the shared helper are unused.
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/false, /*IS_DACT=*/true, Empty,
                                      dsrelu<fp32, fp32>>(grad, input, output, /*dbias=*/nullptr,
                                                          /*workspace=*/nullptr,
                                                          /*quant_config=*/nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream);
}
/*! Fused dSquaredReLU + cast + column-reduction (dbias) over grouped tensors. */
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor activation_input,
                                      NVTEGroupedTensor output, NVTETensor dbias,
                                      NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
  using namespace transformer_engine;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty,
                                      dsrelu<fp32, fp32>>(input, activation_input, output, dbias,
                                                          workspace, /*quant_config=*/nullptr,
                                                          stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
......
......@@ -13,6 +13,14 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
/*! SiLU forward over a grouped tensor; quantization follows the output's scaling mode. */
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_silu);
  using namespace transformer_engine;
  // Forward activation path; no quantization config is supplied.
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/true, Empty, silu<fp32, fp32>>(
      input, output, /*quant_config=*/nullptr, stream);
}
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsilu);
......@@ -20,6 +28,20 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
/*! SiLU backward (dgrad) over grouped tensors; no dbias is produced on this path. */
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_dsilu);
  using namespace transformer_engine;
  // DACT-only path: the dbias/workspace handles of the shared helper are unused.
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/false, /*IS_DACT=*/true, Empty,
                                      dsilu<fp32, fp32>>(grad, input, output, /*dbias=*/nullptr,
                                                         /*workspace=*/nullptr,
                                                         /*quant_config=*/nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
......@@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}
/*! Fused dSiLU + cast + column-reduction (dbias) over grouped tensors. */
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor activation_input,
                                     NVTEGroupedTensor output, NVTETensor dbias,
                                     NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
  using namespace transformer_engine;
  dispatch::group_quantize_bwd_helper</*IS_DBIAS=*/true, /*IS_DACT=*/true, Empty,
                                      dsilu<fp32, fp32>>(input, activation_input, output, dbias,
                                                         workspace, /*quant_config=*/nullptr,
                                                         stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
......
......@@ -26,6 +26,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}
/*! Pure cast of a grouped tensor; quantization format follows the output's scaling mode. */
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize);
  using namespace transformer_engine;
  // No activation is fused (IS_ACT=false, OP=nullptr); no quantization config supplied.
  dispatch::group_quantize_fwd_helper</*IS_ACT=*/false, Empty, nullptr>(
      input, output, /*quant_config=*/nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_noop);
......@@ -60,6 +69,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
input, activation_input, output, dbias, workspace, nullptr, stream);
}
/*! Casts a grouped tensor (format per the output's scaling mode) and reduces the
 *  input along columns into `dbias`. Per the public header, calling with an empty
 *  workspace tensor only queries the required workspace shape/dtype.
 *
 *  \param[in]     input      Grouped tensor to be cast.
 *  \param[in,out] output     Quantized grouped output.
 *  \param[out]    dbias      Column-reduction of the input.
 *  \param[out]    workspace  Workspace tensor (query mode when empty).
 *  \param[in]     stream     CUDA stream used for the operation.
 */
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                               NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_group_quantize_dbias);
  using namespace transformer_engine;
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  // Pure-quantize path: no activation input. Note: plain `const` suffices here;
  // the original `constexpr const` was redundant (constexpr implies const).
  const NVTEGroupedTensor activation_input = nullptr;
  dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
      input, activation_input, output, dbias, workspace, /*quant_config=*/nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
......
......@@ -18,6 +18,7 @@
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/group_quantize_mxfp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
......@@ -371,6 +372,89 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
}
}
// Grouped-tensor forward quantization dispatcher, optionally fused with an
// elementwise activation OP (when IS_ACT is true).
// Routes to the kernel matching the output's scaling mode; only
// NVTE_MXFP8_1D_SCALING is implemented here — any other mode raises NVTE_ERROR.
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;
  NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);
  // The mxfp8::group_quantize entry point is shared with the backward
  // (dbias/dact) path, so the forward path passes null stand-ins for the
  // unused activation, dbias, and workspace arguments.
  const NVTEGroupedTensor activation = nullptr;
  NVTETensor dbias = nullptr;
  NVTETensor workspace = nullptr;
  const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
  GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
  // Non-checking conversions: these handles are intentionally null here.
  const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation);
  Tensor *dbias_tensor = convertNVTETensor(dbias);
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  // Quantization config: optional; a null handle keeps the default-constructed config.
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }
  // Noop flag: if the config carries a noop tensor, forward it to the kernel
  // (which may exit early based on its value); otherwise pass a dummy tensor.
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }
  // Dispatch to quantization kernel depending on data format
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8::group_quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
          input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor,
          workspace_tensor, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
  }
}
// Grouped-tensor backward quantization dispatcher: casts `grad` (optionally
// applying the activation-gradient OP against `input` when IS_DACT, and
// reducing into `dbias` when IS_DBIAS) via the kernel matching the output's
// scaling mode. Only NVTE_MXFP8_1D_SCALING is implemented here — any other
// mode raises NVTE_ERROR.
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                               NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;
  NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);
  const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad);
  // Non-checking conversion: `input` may be null when IS_DACT is false
  // (see the dbias-only callers that pass a null activation_input).
  const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input);
  GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
  // dbias/workspace may likewise be null (dact-only callers pass nullptr).
  Tensor *dbias_tensor = convertNVTETensor(dbias);
  Tensor *workspace_tensor = convertNVTETensor(workspace);
  // Quantization config: optional; a null handle keeps the default-constructed config.
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }
  // Noop flag: if the config carries a noop tensor, forward it to the kernel
  // (which may exit early based on its value); otherwise pass a dummy tensor.
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }
  // Dispatch to quantization kernel depending on data format
  switch (scaling_mode) {
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8::group_quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
          grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
          stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
  }
}
} // namespace dispatch
} // namespace transformer_engine
......
This diff is collapsed.
......@@ -52,6 +52,16 @@ enum class NVTE_Activation_Type {
*/
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -62,6 +72,16 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -72,6 +92,16 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -82,6 +112,16 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -92,6 +132,16 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -104,6 +154,18 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the GeLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -116,6 +178,18 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -128,6 +202,18 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -140,6 +226,18 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -152,6 +250,18 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......
......@@ -89,6 +89,17 @@ extern "C" {
*/
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Casts input grouped tensor to MXFP8.
* The type of quantized tensor in the output depends on the scaling mode of the output
* tensor. See file level comments.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in,out] output Output grouped MXFP8 tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream);
/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
* The type of quantized tensor in the output depends on the scaling mode of the output
......@@ -132,6 +143,26 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
 * - `output` is equal to `cast(input)`
 * - `dbias` is equal to `reduce(input, dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -155,6 +186,29 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of GeLU operation on the grouped input, then casts to FP8/MXFP8.
 * Additionally, reduces the result of the GeLU backward along columns.
 * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 * This function produces 2 results:
 * - `output` is equal to `cast(dact(input))`
 * - `dbias` is equal to `reduce(dact(input), dim=1)`
 *
 * Calling this function with the workspace being an empty tensor will not perform the operation,
 * but instead set the shape and type of the workspace tensor to the required values.
 *
 * \param[in] input Input grouped tensor to be cast.
 * \param[in] act_input Activation input grouped tensor.
 * \param[in,out] output Output grouped FP8/MXFP8 tensor.
 * \param[out] dbias Result of the reduction of the GeLU backward output along columns.
 * \param[out] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 */
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the grouped input, then casts to FP8/MXFP8.
 * Additionally, reduces the result of the SiLU backward along columns.
 * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 * This function produces 2 results:
 * - `output` is equal to `cast(dact(input))`
 * - `dbias` is equal to `reduce(dact(input), dim=1)`
 *
 * Calling this function with the workspace being an empty tensor will not perform the operation,
 * but instead set the shape and type of the workspace tensor to the required values.
 *
 * \param[in] input Input grouped tensor to be cast.
 * \param[in] act_input Activation input grouped tensor.
 * \param[in,out] output Output grouped FP8/MXFP8 tensor.
 * \param[out] dbias Result of the reduction of the SiLU backward output along columns.
 * \param[out] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 */
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the grouped input, then casts to FP8/MXFP8.
 * Additionally, reduces the result of the ReLU backward along columns.
 * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 * This function produces 2 results:
 * - `output` is equal to `cast(dact(input))`
 * - `dbias` is equal to `reduce(dact(input), dim=1)`
 *
 * Calling this function with the workspace being an empty tensor will not perform the operation,
 * but instead set the shape and type of the workspace tensor to the required values.
 *
 * \param[in] input Input grouped tensor to be cast.
 * \param[in] act_input Activation input grouped tensor.
 * \param[in,out] output Output grouped FP8/MXFP8 tensor.
 * \param[out] dbias Result of the reduction of the ReLU backward output along columns.
 * \param[out] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
                                     const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 */
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the grouped input, then casts to FP8/MXFP8.
 * Additionally, reduces the result of the Quick GeLU backward along columns.
 * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 * This function produces 2 results:
 * - `output` is equal to `cast(dact(input))`
 * - `dbias` is equal to `reduce(dact(input), dim=1)`
 *
 * Calling this function with the workspace being an empty tensor will not perform the operation,
 * but instead set the shape and type of the workspace tensor to the required values.
 *
 * \param[in] input Input grouped tensor to be cast.
 * \param[in] act_input Activation input grouped tensor.
 * \param[in,out] output Output grouped FP8/MXFP8 tensor.
 * \param[out] dbias Result of the reduction of the Quick GeLU backward output along columns.
 * \param[out] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
                                      NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 */
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the grouped input, then casts to FP8/MXFP8.
 * Additionally, reduces the result of the Squared ReLU backward along columns.
 * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 * This function produces 2 results:
 * - `output` is equal to `cast(dact(input))`
 * - `dbias` is equal to `reduce(dact(input), dim=1)`
 *
 * Calling this function with the workspace being an empty tensor will not perform the operation,
 * but instead set the shape and type of the workspace tensor to the required values.
 *
 * \param[in] input Input grouped tensor to be cast.
 * \param[in] act_input Activation input grouped tensor.
 * \param[in,out] output Output grouped FP8/MXFP8 tensor.
 * \param[out] dbias Result of the reduction of the Squared ReLU backward output along columns.
 * \param[out] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
                                      NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Casts input tensor from reduced to higher precision.
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,
* the block dequantization (MXFP8) of the specified shape of the block will be used.
 */