Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
@@ -19,19 +19,9 @@ try:
 except (ImportError, StopIteration) as e:
     pass
 
-try:
-    from . import paddle
-except (ImportError, StopIteration) as e:
-    pass
-
 try:
     import transformer_engine_jax
 except ImportError:
     pass
 
-try:
-    import transformer_engine_paddle
-except ImportError:
-    pass
-
 __version__ = str(metadata.version("transformer_engine"))
@@ -6,13 +6,17 @@ cmake_minimum_required(VERSION 3.21)
 # Language options
 if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-  set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
+  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
+    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
+  else ()
+    set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
+  endif()
 endif()
 
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CUDA_STANDARD 17)
 set(CMAKE_CUDA_STANDARD_REQUIRED ON)
 
 if (CMAKE_BUILD_TYPE STREQUAL "Debug")
-  set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
+  set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
 endif()
 
 # Hide non-necessary symbols in shared object.

@@ -78,6 +82,7 @@ list(APPEND transformer_engine_SOURCES
      util/cuda_runtime.cpp
      util/rtc.cpp
      util/system.cpp
+     swizzle/swizzle.cu
      fused_softmax/scaled_masked_softmax.cu
      fused_softmax/scaled_upper_triang_masked_softmax.cu
      fused_softmax/scaled_aligned_causal_masked_softmax.cu
...
@@ -4,111 +4,71 @@
  * See LICENSE for license information.
  ************************************************************************/
 
+/*! \file activation_template.h
+ *  \brief Activation functions template.
+ */
+
+#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
+#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
+
 #include <cuda_runtime.h>
 #include <transformer_engine/activation.h>
 
 #include "../common.h"
+#include "../util/cast_gated_kernels.cuh"
+#include "../util/cast_kernels.cuh"
+#include "../util/math.h"
 #include "../util/vectorized_pointwise.h"
 
 namespace transformer_engine {
 
 template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
-void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
-  CheckInputTensor(input, "act_lu_input");
-  CheckOutputTensor(*output, "act_lu_output");
-  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
-  const size_t tot_elts = product(input.data.shape);
-
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, IType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
-          output->data.dtype, OType,
-
-          constexpr int nvec = 32 / sizeof(IType);
-          VectorizedUnaryKernelLauncher<nvec, Param, OP>(
-              reinterpret_cast<const IType *>(input.data.dptr),
-              reinterpret_cast<OType *>(output->data.dptr),
-              reinterpret_cast<const ComputeType *>(output->scale.dptr),
-              reinterpret_cast<ComputeType *>(output->amax.dptr),
-              reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
-              stream););  // NOLINT(*)
-  );  // NOLINT(*)
+void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+  using namespace detail;
+  constexpr bool IS_DBIAS = false;
+  constexpr bool IS_DACT = false;
+  constexpr bool IS_ACT = true;
+  constexpr NVTETensor dbias = nullptr;
+  constexpr NVTETensor workspace = nullptr;
+  constexpr const NVTETensor grad = nullptr;
+
+  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
+                                                        workspace, stream);
 }
 
 template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
-void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
-  CheckInputTensor(input, "dact_lu_input");
-  CheckInputTensor(grad, "dact_lu_input_grad");
-  CheckOutputTensor(*output, "dact_lu_output");
-  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
-  NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match.");
-  const size_t tot_elts = product(input.data.shape);
-
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, IType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
-          output->data.dtype, OType,
-
-          constexpr int nvec = 32 / sizeof(IType);
-          VectorizedUnaryGradKernelLauncher<nvec, Param, OP>(
-              reinterpret_cast<const IType *>(grad.data.dptr),
-              reinterpret_cast<const IType *>(input.data.dptr),
-              reinterpret_cast<OType *>(output->data.dptr),
-              reinterpret_cast<const ComputeType *>(output->scale.dptr),
-              reinterpret_cast<ComputeType *>(output->amax.dptr),
-              reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
-              stream););  // NOLINT(*)
-  );  // NOLINT(*)
+void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+             cudaStream_t stream) {
+  using namespace detail;
+  constexpr bool IS_DBIAS = false;
+  constexpr bool IS_DACT = true;
+  constexpr bool IS_ACT = false;
+  constexpr NVTETensor dbias = nullptr;
+  constexpr NVTETensor workspace = nullptr;
+
+  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
+                                                        workspace, stream);
 }
 
-template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
-void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
-  CheckInputTensor(input, "gated_act_input");
-  CheckOutputTensor(*output, "gated_act_output");
-  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
-  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
-             "Input shape[0] must be equal to output shape[0].");
-  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
-             "Input shape[1] must be 2x larger than output shape[1].");
-
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, IType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
-          output->data.dtype, OType,
-
-          constexpr int nvec = 32 / sizeof(IType);
-          GatedActivationKernelLauncher<nvec, ComputeType, Param, OP>(
-              reinterpret_cast<const IType *>(input.data.dptr),
-              reinterpret_cast<OType *>(output->data.dptr),
-              reinterpret_cast<const ComputeType *>(output->scale.dptr),
-              reinterpret_cast<ComputeType *>(output->amax.dptr),
-              reinterpret_cast<ComputeType *>(output->scale_inv.dptr), output->data.shape[0],
-              output->data.shape[1], {},
-              stream););  // NOLINT(*)
-  );  // NOLINT(*)
+template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
+void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+  using namespace detail;
+  constexpr bool IS_DGATED = false;
+  constexpr NVTETensor grad = nullptr;
+
+  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
 }
 
-template <typename ComputeType, typename Param, ComputeType (*OP1)(ComputeType, const Param &),
-          ComputeType (*OP2)(ComputeType, const Param &)>
-void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
-  CheckInputTensor(grad, "dgated_act_grad");
-  CheckInputTensor(input, "dgated_act_input");
-  CheckOutputTensor(*output, "dgated_act_output");
-  NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
-  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
-  NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
-             "Output shape[0] must be equal to grad shape[0].");
-  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
-             "Output shape[1] must be 2x larger than grad shape[1].");
-  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
-
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, IType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
-          output->data.dtype, OType,
-
-          constexpr int nvec = 32 / sizeof(IType);
-          DGatedActivationKernelLauncher<nvec, ComputeType, Param, OP1, OP2>(
-              reinterpret_cast<const IType *>(grad.data.dptr),
-              reinterpret_cast<const IType *>(input.data.dptr),
-              reinterpret_cast<OType *>(output->data.dptr), grad.data.shape[0], grad.data.shape[1],
-              {},
-              stream););  // NOLINT(*)
-  );  // NOLINT(*)
+template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
+          ComputeType (*DActOP)(ComputeType, const Param &)>
+void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+                   cudaStream_t stream) {
+  using namespace detail;
+  constexpr bool IS_DGATED = true;
+
+  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
 }
 
 }  // namespace transformer_engine
+
+#endif  // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
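Illustration (not part of the commit): with the new signatures, the C API wrappers in the next file forward opaque NVTETensor handles straight into these templates. A minimal, hypothetical call site, assuming `x`, `y`, `dy`, and `dx` are valid NVTETensor handles already created and shaped by the caller, could look like this:

    #include <cuda_runtime.h>
    #include <transformer_engine/activation.h>

    // Hypothetical call site; tensor handle creation/destruction is omitted.
    void gelu_forward_backward(NVTETensor x, NVTETensor y, NVTETensor dy, NVTETensor dx,
                               cudaStream_t stream) {
      nvte_gelu(x, y, stream);        // forward: y = GELU(x)
      nvte_dgelu(dy, x, dx, stream);  // backward: dx = GELU'(x) * dy
    }
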
@@ -3,69 +3,58 @@
  *
  * See LICENSE for license information.
  ************************************************************************/
 
 #include "../util/math.h"
 #include "./activation_template.h"
 
 void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_gelu);
   using namespace transformer_engine;
-  act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                        reinterpret_cast<Tensor*>(output), stream);
+  act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
   NVTE_API_CALL(nvte_dgelu);
   using namespace transformer_engine;
-  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
-                                          *reinterpret_cast<const Tensor*>(input),
-                                          reinterpret_cast<Tensor*>(output), stream);
+  dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
 }
 
 void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_geglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                              reinterpret_cast<Tensor*>(output), stream);
+  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dgeglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
-      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
-      reinterpret_cast<Tensor*>(output), stream);
+  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, stream);
 }
 
 void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_qgelu);
   using namespace transformer_engine;
-  act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                         reinterpret_cast<Tensor*>(output), stream);
+  act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dqgelu);
   using namespace transformer_engine;
-  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
-                                           *reinterpret_cast<const Tensor*>(input),
-                                           reinterpret_cast<Tensor*>(output), stream);
+  dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
 }
 
 void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_qgeglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                               reinterpret_cast<Tensor*>(output), stream);
+  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dqgeglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(
-      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
-      reinterpret_cast<Tensor*>(output), stream);
+  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, stream);
 }
@@ -10,63 +10,51 @@
 void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_relu);
   using namespace transformer_engine;
-  act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                        reinterpret_cast<Tensor*>(output), stream);
+  act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
   NVTE_API_CALL(nvte_drelu);
   using namespace transformer_engine;
-  dact_fn<fp32, Empty, drelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
-                                          *reinterpret_cast<const Tensor*>(input),
-                                          reinterpret_cast<Tensor*>(output), stream);
+  dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
 }
 
 void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_reglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                              reinterpret_cast<Tensor*>(output), stream);
+  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dreglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
-      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
-      reinterpret_cast<Tensor*>(output), stream);
+  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, stream);
 }
 
 void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_srelu);
   using namespace transformer_engine;
-  act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                         reinterpret_cast<Tensor*>(output), stream);
+  act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dsrelu);
   using namespace transformer_engine;
-  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
-                                           *reinterpret_cast<const Tensor*>(input),
-                                           reinterpret_cast<Tensor*>(output), stream);
+  dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
 }
 
 void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_sreglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                               reinterpret_cast<Tensor*>(output), stream);
+  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dsreglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(
-      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
-      reinterpret_cast<Tensor*>(output), stream);
+  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, stream);
 }
@@ -10,31 +10,25 @@
 void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_silu);
   using namespace transformer_engine;
-  act_fn<fp32, Empty, silu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                        reinterpret_cast<Tensor*>(output), stream);
+  act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
   NVTE_API_CALL(nvte_dsilu);
   using namespace transformer_engine;
-  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
-                                          *reinterpret_cast<const Tensor*>(input),
-                                          reinterpret_cast<Tensor*>(output), stream);
+  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
 }
 
 void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_swiglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
-                                              reinterpret_cast<Tensor*>(output), stream);
+  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
 }
 
 void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dswiglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(
-      *reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
-      reinterpret_cast<Tensor*>(output), stream);
+  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
 }
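Illustration (not part of the commit): the gated activations pack the activation and gate halves along the last dimension, which is the shape relationship the removed NVTE_CHECKs used to spell out (forward input [., 2H] -> output [., H]; backward grad [., H] -> grad output [., 2H]). A standalone sketch with made-up sizes:

    #include <cstddef>
    #include <cstdio>

    // SwiGLU-style shapes with hypothetical sizes.
    int main() {
      const std::size_t tokens = 4096, hidden = 11008;  // illustrative FFN sizes

      const std::size_t fwd_in[2] = {tokens, 2 * hidden};   // nvte_swiglu input
      const std::size_t fwd_out[2] = {tokens, hidden};      // nvte_swiglu output
      const std::size_t bwd_grad[2] = {tokens, hidden};     // nvte_dswiglu incoming grad
      const std::size_t bwd_out[2] = {tokens, 2 * hidden};  // nvte_dswiglu output

      std::printf("fwd: [%zu, %zu] -> [%zu, %zu]\n", fwd_in[0], fwd_in[1], fwd_out[0], fwd_out[1]);
      std::printf("bwd: [%zu, %zu] -> [%zu, %zu]\n", bwd_grad[0], bwd_grad[1], bwd_out[0], bwd_out[1]);
      return 0;
    }
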
@@ -21,6 +21,8 @@
 #define HALF_BYTES 2
 #define UB_MAX_SM 32
 
+#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)
+
 using namespace std::placeholders;
 
 namespace transformer_engine {

@@ -40,8 +42,9 @@ bool ubuf_built_with_mpi() {
 CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode,
                                  int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
                                  ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
-                                 int comm_cga_size, int num_comm_sm, bool set_sm_margin,
-                                 bool use_ce, bool atomic_gemm) {
+                                 int comm_cga_size, int gemm_priority, int comm_priority,
+                                 int num_comm_sm, bool set_sm_margin, bool use_ce,
+                                 bool atomic_gemm) {
   // Initialize userbuf communicator
   if (!_comm_created) {
     if (myrank == 0) {

@@ -59,9 +62,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
   _num_comm_sm = num_comm_sm;
   _cga_size = comm_cga_size;
 
+  if (gemm_priority == 0 && comm_priority == 0) {
+    transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority);
+  } else {
+    _gemm_priority = gemm_priority;
+    _comm_priority = comm_priority;
+  }
+
   for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
     cudaStream_t stream;
-    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1));
+    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority));
     _stream_compute.push_back(std::move(stream));
   }
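Illustration (not part of the commit): the new `gemm_priority`/`comm_priority` arguments let callers pin GEMM and communication work to CUDA streams of different priorities; when both are left at 0, the range is queried from the device. A minimal standalone sketch of that pattern, assuming `transformer_engine::cuda::stream_priority_range` behaves like CUDA's `cudaDeviceGetStreamPriorityRange`:

    #include <cuda_runtime.h>

    int main() {
      // Query the device's supported stream priority range (numerically lower = higher priority).
      int least_priority = 0, greatest_priority = 0;
      cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority);

      // One plausible assignment (the mapping used inside TE may differ): compute (GEMM)
      // streams at the lower priority, communication streams at the higher priority so
      // communication kernels get scheduled promptly alongside long-running math kernels.
      cudaStream_t gemm_stream, comm_stream;
      cudaStreamCreateWithPriority(&gemm_stream, cudaStreamNonBlocking, least_priority);
      cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, greatest_priority);

      cudaStreamDestroy(gemm_stream);
      cudaStreamDestroy(comm_stream);
      return 0;
    }
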
@@ -130,6 +139,73 @@ CommOverlapCore::~CommOverlapCore() {
   }
 }
 
+TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
+                                                const std::vector<size_t> &chunk_shape) {
+  TensorWrapper chunk;
+  for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
+    auto param_type = static_cast<NVTETensorParam>(param_id);
+    auto param = source.get_parameter(param_type);
+    auto param_dptr = reinterpret_cast<char *>(param.data_ptr);
+    auto param_dtype = static_cast<DType>(param.dtype);
+    auto param_shape = AS_VECTOR(param.shape);
+    if (param_dptr != nullptr) {
+      if (param_type == NVTETensorParam::kNVTERowwiseData ||
+          param_type == NVTETensorParam::kNVTEColumnwiseData) {
+        // Offset data pointer
+        param_dptr += chunk_offset * typeToSize(param_dtype);
+        param_shape = chunk_shape;
+        if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
+            source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
+          // Columnwise shape for non-block scaled tensors shifts the last dimension to the front
+          auto last_dim = param_shape.back();
+          param_shape.pop_back();
+          param_shape.insert(param_shape.begin(), last_dim);
+        }
+      } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
+                 (param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
+                  param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
+        // Calculate block scaling offset and size
+        auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
+                                          ? source.shape().data[0]
+                                          : source.columnwise_shape().data[0];
+        auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
+                                         ? chunk_shape.front()
+                                         : chunk_shape.back();
+        auto chunk_scale_start = chunk_offset / 32;
+        auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
+        auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
+        param_dptr += chunk_scale_start * typeToSize(param_dtype);
+        param_shape = std::vector<size_t>{chunk_scale_size};
+      }
+
+      // Set chunked source parameters into the chunked tensor output
+      chunk.set_parameter(param_type, reinterpret_cast<void *>(param_dptr), param_dtype,
+                          param_shape);
+    }
+  }
+  return chunk;
+}
+
+TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source,
+                                                     size_t chunk_offset,
+                                                     const std::vector<size_t> &chunk_shape) {
+  // Start with a chunk of the source tensor
+  auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape);
+
+  // Update chunk with offset data pointers from the communication buffer
+  auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size());
+  if (chunk.dptr() != nullptr) {
+    chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape());
+  }
+  if (chunk.columnwise_dptr() != nullptr) {
+    chunk.set_columnwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(),
+                              chunk.columnwise_shape());
+  }
+  return chunk;
+}
 
 /***************************************************************************************************
  * Comm+GEMM Overlap Base (Pipelined / Collective)
  **************************************************************************************************/

@@ -138,11 +214,14 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
                                  int myrank, int numranks, int mylocal, int numlocal, int mynode,
                                  int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
                                  ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
-                                 int comm_cga_size, int num_comm_sm, bool set_sm_margin,
-                                 bool atomic_gemm)
+                                 int comm_cga_size, int gemm_priority, int comm_priority,
+                                 int num_comm_sm, bool set_sm_margin, bool atomic_gemm,
+                                 bool rs_overlap_first_gemm)
     : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
                       allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
-                      num_comm_sm, set_sm_margin, false, atomic_gemm) {
+                      gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
+                      atomic_gemm) {
+  _rs_overlap_first_gemm = rs_overlap_first_gemm;
   _rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
   NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
              "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ",

@@ -155,7 +234,8 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
   if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
   _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
 
-  NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1));
+  NVTE_CHECK_CUDA(
+      cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority));
   NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
 }

@@ -168,8 +248,8 @@ CommOverlapBase::~CommOverlapBase() {
 ** Bulk GEMM + COMM
 ** This function assumes the communication input is pre-copied to _ubuf
 */
-void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
-                                   TensorWrapper &D, TensorWrapper &bias,
+void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
+                                   bool transb, TensorWrapper &D, TensorWrapper &bias,
                                    TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
                                    bool accumulate, bool use_split_accumulator,
                                    CommOverlapType comm_type, TensorWrapper &rs_output,

@@ -196,7 +276,7 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
     assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
     assert(rs_output.element_size() == 2);
     char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
-    reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
+    reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
                                                comm_elements, _ub_comm, _stream_comm,
                                                (cudaEvent_t)_comm_launch_event);
   } else {

@@ -221,20 +301,20 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
 /*
 ** Split FPROP GEMM + ReduceScatter
 */
-void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
-                                             bool transb, TensorWrapper &D, TensorWrapper &bias,
-                                             TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
-                                             bool grad, bool accumulate, bool use_split_accumulator,
-                                             bool gemm_overlap, TensorWrapper &rs_output,
+void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa,
+                                             const TensorWrapper &B, bool transb, TensorWrapper &D,
+                                             TensorWrapper &bias, TensorWrapper &pre_gelu_out,
+                                             TensorWrapper &workspace, bool grad, bool accumulate,
+                                             bool use_split_accumulator, TensorWrapper &rs_output,
                                              cudaStream_t stream_main) {
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
   _ub_comm->cga_size = _cga_size;
 
   // Get GEMM dimensions
-  size_t m = A.size(0);
-  size_t k = A.size(1);
-  size_t n = B.size(0);
+  size_t m = transa ? A.size(0) : A.size(1);
+  size_t k = transa ? A.size(1) : A.size(0);
+  size_t n = _ubuf.size(0);
   size_t m_chunk = m / _num_splits;
   size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

@@ -255,9 +335,8 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
   assert(pre_gelu_out.numel() == 0);
 
-  auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr);
-  auto workspace_chunk =
-      TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
+  auto output_d = get_buffer_chunk_like(D, 0, {n, m});
+  auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
 
   nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
                           transa, transb, grad, workspace_chunk.data(), accumulate,
                           use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(),

@@ -269,11 +348,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
       _ub_comm->sms = UB_MAX_SM;
     }
     if (_ubuf.element_size() == 1) {
-      assert(_ubuf_scale_inv_initialized);
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
           D.dtype(), fp8_type,
           reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
-              rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
+              rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
              &counter_ptr[i], _ub_comm, _stream_comm););
     } else {
       reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,

@@ -282,11 +360,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
     }
   } else if (_rs_kernel_type == 2) {
     if (_ubuf.element_size() == 1) {
-      assert(_ubuf_scale_inv_initialized);
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
           D.dtype(), fp8_type,
           reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
-              rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
+              rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
              counter_ptr, _ub_comm, _stream_comm););
     } else {
       reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m,

@@ -299,7 +376,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
     if (_ubuf.element_size() == 1) {
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
           D.dtype(), fp8_type,
-          reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(rs_output_ptr, _ubuf_scale_inv,
+          reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(rs_output_ptr, D.scale_inv(),
                                                               _ub_reg, i * m_chunk, m_chunk, n, m,
                                                               _ub_comm, _stream_comm););
     } else {

@@ -321,34 +398,24 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
 /*
 ** Split FPROP GEMM + ReduceScatter
 */
-void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
-                                       TensorWrapper &D, TensorWrapper &bias,
+void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
+                                       bool transb, TensorWrapper &D, TensorWrapper &bias,
                                        TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
                                        bool grad, bool accumulate, bool use_split_accumulator,
-                                       bool gemm_overlap, TensorWrapper &rs_output,
-                                       cudaStream_t stream_main) {
+                                       TensorWrapper &rs_output, cudaStream_t stream_main) {
   // Get GEMM dimensions
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
   _ub_comm->cga_size = _cga_size;
-  size_t m = A.size(0);
-  size_t k = A.size(1);
-  size_t n = B.size(0);
+  size_t m = transa ? A.size(0) : A.size(1);
+  size_t k = transa ? A.size(1) : A.size(0);
+  size_t n = _ubuf.size(0);
   size_t m_chunk = m / _num_splits;
   size_t input_a_chunk_size = m_chunk * k;
   size_t output_chunk_size = n * m_chunk;
-  size_t bias_chunk_size = m_chunk;
   size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
 
-  // Get input, output, and workspace data pointers
-  char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
-  char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
-  char *bias_chunk_ptr = reinterpret_cast<char *>(bias.dptr());
-  char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
-  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
-
   // Catch up the default torch stream
   NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
   for (size_t i = 0; i < _stream_compute.size(); i++) {

@@ -358,39 +425,23 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
   assert(pre_gelu_out.numel() == 0);
 
-  if (gemm_overlap) {
-    auto input_a_chunk =
-        TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv());
-    auto output_chunk =
-        TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
-    auto bias_chunk =
-        TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr);
-    auto workspace_chunk = TensorWrapper(
-        workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
-
-    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
+  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
+  if (_rs_overlap_first_gemm) {
+    auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
+    auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
+    auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
+
+    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                      use_split_accumulator, _math_sms, _stream_compute[0]);
 
     for (int i = 1; i < _num_splits; i++) {
-      input_a_chunk_ptr += input_a_chunk_size * B.element_size();
-      output_buf_chunk_ptr += output_chunk_size * D.element_size();
-      if (bias_chunk_ptr != nullptr) {
-        bias_chunk_ptr += bias_chunk_size * bias.element_size();
-      }
-      char *workspace_chunk_ptr =
-          workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
-
-      input_a_chunk = TensorWrapper(reinterpret_cast<void *>(input_a_chunk_ptr), {m_chunk, k},
-                                    A.dtype(), nullptr, nullptr, A.scale_inv());
-      output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr), {n, m_chunk},
-                                   D.dtype(), D.amax(), D.scale(), nullptr);
-      bias_chunk = TensorWrapper(reinterpret_cast<void *>(bias_chunk_ptr), {m_chunk}, bias.dtype(),
-                                 nullptr, nullptr, nullptr);
-      workspace_chunk = TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
-                                      std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
-
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
+      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
+      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      workspace_chunk = get_tensor_chunk(
+          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
+
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                        pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                        accumulate, use_split_accumulator, _math_sms,
                        _stream_compute[i % _stream_compute.size()]);

@@ -401,11 +452,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
       // Communication chunk
       if (_ubuf.element_size() == 1) {
-        assert(_ubuf_scale_inv_initialized);
         TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
             D.dtype(), fp8_type,
             reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
-                rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
+                rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
                _ub_comm, _stream_comm););
       } else {
         reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,

@@ -422,12 +472,11 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
     // Last communication chunk with max SM
     _ub_comm->sms = UB_MAX_SM;
     if (_ubuf.element_size() == 1) {
-      assert(_ubuf_scale_inv_initialized);
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
           D.dtype(), fp8_type,
           reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
-              rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size,
-              m_chunk, n, m, _ub_comm, _stream_comm););
+              rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk,
+              n, m, _ub_comm, _stream_comm););
     } else {
       reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
                                             (_num_splits - 1) * output_chunk_size, m_chunk, n, m,

@@ -435,20 +484,12 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
     }
   } else {
     for (int i = 0; i < _num_splits; i++) {
-      char *workspace_chunk_ptr =
-          workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
-
-      auto input_a_chunk = TensorWrapper(reinterpret_cast<void *>(input_a_chunk_ptr), {m_chunk, k},
-                                         A.dtype(), nullptr, nullptr, A.scale_inv());
-      auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr),
-                                        {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
-      auto bias_chunk = TensorWrapper(reinterpret_cast<void *>(bias_chunk_ptr), {m_chunk},
-                                      bias.dtype(), nullptr, nullptr, nullptr);
-      auto workspace_chunk =
-          TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
-                        std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
-
-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
+      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
+      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
+      auto workspace_chunk = get_tensor_chunk(
+          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
+
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                        pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                        accumulate, use_split_accumulator, _math_sms,
                        _stream_compute[i % _stream_compute.size()]);

@@ -461,11 +502,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
         _ub_comm->sms = UB_MAX_SM;
       }
       if (_ubuf.element_size() == 1) {
-        assert(_ubuf_scale_inv_initialized);
         TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
             D.dtype(), fp8_type,
             reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
-                rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m,
+                rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m,
                _ub_comm, _stream_comm););
       } else {
         reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,

@@ -473,11 +513,6 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
       }
 
       rs_output_ptr += m_chunk * rs_output.element_size();
-      input_a_chunk_ptr += input_a_chunk_size * B.element_size();
-      output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
-      if (bias_chunk_ptr != nullptr) {
-        bias_chunk_ptr += bias_chunk_size * bias.element_size();
-      }
     }
   }
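Illustration (not part of the commit): the chunk sizes computed at the top of `split_overlap_rs` determine where each split reads its slice of A and writes its slice of the output buffer. A standalone sketch with hypothetical dimensions:

    #include <cstddef>
    #include <cstdio>

    // Split-GEMM chunking: the M dimension is cut into num_splits pieces, and each
    // split's offsets into A and into the userbuffer output are multiples of these sizes.
    int main() {
      const std::size_t m = 8192, k = 1024, n = 4096;  // illustrative GEMM dims
      const std::size_t num_splits = 4;

      const std::size_t m_chunk = m / num_splits;          // 2048 per split
      const std::size_t input_a_chunk_size = m_chunk * k;  // elements of A per split
      const std::size_t output_chunk_size = n * m_chunk;   // elements of D per split

      for (std::size_t i = 0; i < num_splits; ++i) {
        std::printf("split %zu: A offset %zu, D offset %zu\n", i, i * input_a_chunk_size,
                    i * output_chunk_size);
      }
      return 0;
    }
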
...@@ -499,11 +534,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, ...@@ -499,11 +534,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
int mynode, int numnodes, int tp_size, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams, CommOverlapType comm_type, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin, int comm_cga_size, int gemm_priority, int comm_priority,
bool use_ce, bool atomic_gemm, bool aggregate) int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm, bool aggregate)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
_is_p2p = true; _is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS; _is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate; _aggregate = aggregate;
...@@ -552,8 +589,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, ...@@ -552,8 +589,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
} }
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1)); for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1)); cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream));
}
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
} }
...@@ -562,7 +604,22 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { ...@@ -562,7 +604,22 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send); cudaEventDestroy(_stop_send);
cudaStreamDestroy(_stream_recv); cudaStreamDestroy(_stream_recv);
cudaStreamDestroy(_stream_send); for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
}
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape()));
// Update chunk with offset data pointers from the communication buffer
if (chunk.dptr() != nullptr) {
chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape());
}
if (chunk.columnwise_dptr() != nullptr) {
chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape());
}
return chunk;
} }
/* /*
...@@ -570,12 +627,10 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { ...@@ -570,12 +627,10 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
** outputs in each rank to be in the contiguous memory space after all ring exchange phases. ** outputs in each rank to be in the contiguous memory space after all ring exchange phases.
*/ */
void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, void CommOverlapP2PBase::atomic_gemm_overlap_ag(
bool transb, TensorWrapper &D, TensorWrapper &bias, const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &pre_gelu_out, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) {
bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms; int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
...@@ -583,8 +638,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ...@@ -583,8 +638,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
// Get GEMM dimensions between TN and NN input layouts // Get GEMM dimensions between TN and NN input layouts
const size_t m = (transa) ? A.size(0) : A.size(1); const size_t m = (transa) ? A.size(0) : A.size(1);
const size_t n = _ubuf.size(0); const size_t n_chunk = _ubufs[0].size(0);
const size_t n_chunk = n / _tp_size;
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
// Get communication and GEMM output chunk sizes // Get communication and GEMM output chunk sizes
...@@ -594,7 +648,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ...@@ -594,7 +648,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
void *D_buffer_ptr; void *D_buffer_ptr;
int D_chunk_bytes = n_chunk * m * D.element_size(); int D_chunk_bytes = n_chunk * m * D.element_size();
NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main));
auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(),
D.scale_inv(), D.scale_inv_shape(), D.scaling_mode());
// Reset atomic counters // Reset atomic counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr()); int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
...@@ -602,13 +657,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ...@@ -602,13 +657,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
// Catch up the default torch stream // Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape()));
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk = auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
for (int i = 0; i < _tp_size - 1; i++) { for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. Buffer under send is the input for the current // Set the userbuffer id. Buffer under send is the input for the current
...@@ -649,8 +703,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ...@@ -649,8 +703,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
NVTE_CHECK_CUDA( NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send)); cudaMemcpyDeviceToDevice, _stream_send[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
} }
...@@ -674,11 +728,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ...@@ -674,11 +728,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
** outputs in each rank to be in the contiguous memory space after all ring exchange phases. ** outputs in each rank to be in the contiguous memory space after all ring exchange phases.
*/ */
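// Illustrative caller-side sketch of the pre-copy assumed above (the input tensor name
// and stream are assumptions, not part of this change):
//   NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), B_local.dptr(),
//                                   B_local.numel() * B_local.element_size(),
//                                   cudaMemcpyDeviceToDevice, stream_main));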
void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
bool transb, TensorWrapper &D, TensorWrapper &bias, const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &workspace, bool grad, bool accumulate,
TensorWrapper &B_copy, cudaStream_t stream_main) { bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms; int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
...@@ -691,24 +746,20 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -691,24 +746,20 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
// Get communication and GEMM output chunk sizes // Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0; const bool do_gelu = pre_gelu_out.numel() > 0;
const int output_chunk_bytes = (n_chunk * m) * D.element_size(); size_t input_chunk_size = n_chunk * k;
const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; size_t output_chunk_size = n_chunk * m;
// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.dptr());
char *pre_gelu_out_ptr = reinterpret_cast<char *>(pre_gelu_out.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) { for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
} }
if (_aggregate) { if (_aggregate) {
const int num_steps = _tp_size / 2; const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.dptr()); input_chunk_size *= 2;
output_chunk_size *= 2;
// Initial 1X input chunk exchange between neighboring peers // Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id; int send_chunk_id = _tp_id;
...@@ -717,11 +768,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -717,11 +768,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
int recv_offset = comm_bytes * recv_chunk_id; int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
_stream_send); _stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
_stream_recv); _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1; int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
...@@ -736,27 +787,15 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -736,27 +787,15 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
recv_offset = comm_bytes * recv_chunk_id; recv_offset = comm_bytes * recv_chunk_id;
// GEMM // GEMM
char *input_b_chunk_ptr = input_b_ptr + send_offset;
auto input_b_chunk = auto input_b_chunk =
TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
nullptr, nullptr, B.scale_inv()); auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
auto aux_chunk =
char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); (do_gelu)
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr), ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
{n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
auto workspace_chunk = get_tensor_chunk(
char *aux_chunk_ptr = workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
(do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
auto aux_chunk_shape =
(do_gelu) ? std::vector<size_t>{n_chunk * 2, m} : std::vector<size_t>{0};
auto aux_chunk = TensorWrapper(reinterpret_cast<void *>(aux_chunk_ptr), aux_chunk_shape,
pre_gelu_out.dtype());
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
...@@ -766,11 +805,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -766,11 +805,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
if (i < num_steps - 1) { if (i < num_steps - 1) {
// P2P communication // P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, _stream_send); next_rank, _stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, _stream_recv); prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA( NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) { } else if (B_copy.numel() > 0) {
...@@ -778,7 +817,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -778,7 +817,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send)); cudaMemcpyDeviceToDevice, _stream_send[0]));
} }
} }
} else { } else {
...@@ -793,24 +832,14 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -793,24 +832,14 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
int recv_offset = comm_bytes * recv_chunk_id; int recv_offset = comm_bytes * recv_chunk_id;
// GEMM // GEMM
auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
nullptr, nullptr, B.scale_inv()); auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
auto aux_chunk =
char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); (do_gelu)
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr), {n_chunk, m}, ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
D.dtype(), D.amax(), D.scale(), nullptr); : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
auto workspace_chunk = get_tensor_chunk(
char *aux_chunk_ptr = workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
(do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
auto aux_chunk_shape = (do_gelu) ? std::vector<size_t>{n_chunk, m} : std::vector<size_t>{0};
auto aux_chunk = TensorWrapper(reinterpret_cast<void *>(aux_chunk_ptr), aux_chunk_shape,
pre_gelu_out.dtype());
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
...@@ -820,11 +849,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -820,11 +849,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
if (i < _tp_size - 1) { if (i < _tp_size - 1) {
// P2P communication // P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, _stream_send); _next_rank, _stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, _stream_recv); _prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA( NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) { } else if (B_copy.numel() > 0) {
...@@ -832,7 +861,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -832,7 +861,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send)); cudaMemcpyDeviceToDevice, _stream_send[0]));
} }
} }
} }
...@@ -842,7 +871,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -842,7 +871,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
} }
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
...@@ -851,13 +880,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW ...@@ -851,13 +880,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
/* /*
** Split ReduceScatter + GEMM using P2P communication ** Split ReduceScatter + GEMM using P2P communication
*/ */
void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, void CommOverlapP2PBase::atomic_gemm_overlap_rs(
bool transb, TensorWrapper &D, TensorWrapper &bias, const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &pre_gelu_out, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
bool accumulate, bool use_split_accumulator, cudaStream_t stream_main) {
TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms; int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
...@@ -876,14 +903,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T ...@@ -876,14 +903,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T
// Atomic GEMM // Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks. // Process GEMM chunks in the order that AG+GEMM places the output chunks.
auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape()));
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk =
TensorWrapper(workspace.data(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace_chunk.data(), accumulate, transa, transb, grad, workspace.data(), accumulate, use_split_accumulator,
use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(), _math_sms, 0, _tp_size, true, _counter.data(), stream_main);
stream_main);
// P2P communication chunk // P2P communication chunk
for (int i = 1; i < _tp_size; i++) { for (int i = 1; i < _tp_size; i++) {
...@@ -907,10 +930,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T ...@@ -907,10 +930,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr()); char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type, D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size,
_ubufs[0].numel(), stream_main);); _ubufs[0].numel(), stream_main););
} else { } else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
...@@ -921,31 +943,33 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T ...@@ -921,31 +943,33 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T
/* /*
** Split ReduceScatter + GEMM using P2P communication ** Split ReduceScatter + GEMM using P2P communication
*/ */
void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
bool transb, TensorWrapper &D, TensorWrapper &bias, const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &workspace, bool grad, bool accumulate,
TensorWrapper &rs_output, cudaStream_t stream_main) { bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms; int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce; _ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm; _ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size; _ub_comm->cga_size = _cga_size;
size_t k = A.size(1);
size_t n = B.size(0);
// Get communication and GEMM input chunk sizes // Get communication and GEMM input chunk sizes
size_t n_chunk = n / _tp_size; size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n_chunk = _ubufs[0].size(0);
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int input_b_chunk_bytes = n_chunk * k * B.element_size();
// Get input and workspace data pointers // Get input and workspace data pointers
char *input_b_ptr = reinterpret_cast<char *>(B.dptr()); size_t input_chunk_size = n_chunk * k;
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr()); size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Catch up the main stream // Catch up the main stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); for (size_t i = 0; i < _stream_send.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0));
}
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) { for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
...@@ -954,36 +978,30 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW ...@@ -954,36 +978,30 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW
// GEMM and send/recv chunks // GEMM and send/recv chunks
for (int i = 0; i < _tp_size; i++) { for (int i = 0; i < _tp_size; i++) {
// GEMM chunk // GEMM chunk
int stream_id = i % _stream_compute.size();
int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes);
auto input_b_chunk = TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk, k},
B.dtype(), nullptr, nullptr, B.scale_inv());
auto output_chunk =
TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr);
char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
auto output_chunk = get_buffer_chunk_by_id(D, i);
auto workspace_chunk = auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr), get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); use_split_accumulator, _math_sms, _stream_compute[stream_id]);
if (i > 0) { if (i > 0) {
// P2P communication chunk // P2P communication chunk
int prev_stream_id = (i - 1) % _stream_compute.size();
int send_offset = comm_bytes * (i - 1); int send_offset = comm_bytes * (i - 1);
int recv_offset = comm_bytes * (i - 1 + _tp_size); int recv_offset = comm_bytes * (i - 1 + _tp_size);
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
NVTE_CHECK_CUDA( NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id]));
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
_stream_send); _stream_send[prev_stream_id]);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
_stream_recv); _stream_recv);
} }
...@@ -993,8 +1011,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW ...@@ -993,8 +1011,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
} }
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
...@@ -1002,11 +1022,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW ...@@ -1002,11 +1022,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr()); char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type, D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size,
_ubufs[0].numel(), stream_main);); _ubufs[0].numel(), stream_main););
} else { } else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <stdio.h> #include <stdio.h>
#include <unistd.h> #include <unistd.h>
#include "common/util/system.h"
#include "userbuffers.h" #include "userbuffers.h"
#define MAX_THREADS 1024 #define MAX_THREADS 1024
......
...@@ -6,27 +6,138 @@ ...@@ -6,27 +6,138 @@
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <bit>
#include "./common.h" #include "./common.h"
#include "./utils.cuh" #include "./utils.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
namespace transformer_engine { namespace transformer_engine {
namespace { namespace {
__global__ void __launch_bounds__(1) __global__ void __launch_bounds__(1)
update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr, update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr,
float* __restrict__ scale_inv_ptr) { float *__restrict__ scale_inv_ptr) {
const float scale = scale_ptr == nullptr ? 1 : *scale_ptr; const float scale = scale_ptr == nullptr ? 1 : *scale_ptr;
reciprocal<float>(scale_inv_ptr, scale); reciprocal<float>(scale_inv_ptr, scale);
} }
} // namespace } // namespace
void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) { void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
if (t->scale_inv.dptr != nullptr) { if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) {
NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float*>(t->scale.dptr), reinterpret_cast<float*>(t->scale_inv.dptr)); reinterpret_cast<const float *>(t->scale.dptr),
reinterpret_cast<float *>(t->scale_inv.dptr));
} }
} }
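// Illustrative numbers (values are assumptions): if *scale_ptr == 448.f, the kernel
// above writes scale_inv = 1.f / 448.f, so dequantizing an FP8 value x amounts to
//   float x_hp = static_cast<float>(x) * (*scale_inv_ptr);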
void checkCuDriverContext(CUstream stream) {
CUcontext ctx;
const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
switch (driver_status) {
case CUDA_SUCCESS:
break;
case CUDA_ERROR_INVALID_CONTEXT:
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &ctx, current_device);
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, ctx);
break;
default:
const char *desc_NVTE_CHECK_CUDA_DRIVER;
cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER);
NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER);
}
}
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
{DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
{DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
{DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
{DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
return dtypeMapping.at(dtype);
}
inline bool isPointerAligned(const void *const ptr, const int alignment) {
const uint64_t ptr_as_uint = reinterpret_cast<uint64_t>(ptr);
return ptr_as_uint % alignment == 0;
}
// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size) {
// Get a function pointer to the cuTensorMapEncodeTiled driver API
static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() {
void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
}();
// rank is the number of dimensions of the array
constexpr uint32_t rank = 2;
uint64_t size[rank] = {globalX, globalY};
// The stride is the number of bytes to traverse from the first element of one row to the next
uint64_t stride[rank - 1] = {stride_elems * type_size};
// The boxSize is the size of the shared memory buffer that is used as the
// source/destination of a TMA transfer
uint32_t boxSize[rank] = {shmemX, shmemY};
// The distance between elements in units of sizeof(element)
uint32_t elemStride[rank] = {1, 1};
const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
void *dataPtr =
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + offset_elems * type_size);
constexpr int TMA_gmem_alignment = 16; // Alignment of the global memory address
NVTE_CHECK(isPointerAligned(dataPtr, TMA_gmem_alignment),
"Tensor data pointer must be 16B aligned");
const int TMA_needed_size = TMA_gmem_alignment / type_size;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size,
"-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
// Create the tensor descriptor.
NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
&tensorMap, // CUtensorMap *tensorMap,
tensorDataType,
rank, // cuuint32_t tensorRank,
dataPtr, // void *globalAddress,
size, // const cuuint64_t *globalDim,
stride, // const cuuint64_t *globalStrides,
boxSize, // const cuuint32_t *boxDim,
elemStride, // const cuuint32_t *elementStrides,
// Interleave patterns can be used to accelerate loading of values that
// are less than 4 bytes long.
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
// Swizzling can be used to avoid shared memory bank conflicts.
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
// L2 Promotion can be used to widen the effect of a cache-policy to a wider
// set of L2 cache lines.
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
// CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
// Any element that is outside of bounds will be set to zero by the TMA transfer.
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
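// Illustrative usage sketch (pointer, shape, and tile sizes are assumptions): build a
// TMA descriptor for a 1024 x 512 FP8 tensor read in 64 x 32 shared-memory tiles.
//   CUtensorMap tensor_map{};
//   SimpleTensor input(dptr, {1024, 512}, DType::kFloat8E4M3);
//   checkCuDriverContext(stream);
//   create_2D_tensor_map(tensor_map, input,
//                        /*globalY=*/1024, /*globalX=*/512,
//                        /*shmemY=*/64, /*shmemX=*/32,
//                        /*stride_elems=*/512, /*offset_elems=*/0,
//                        /*type_size=*/1);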
bool is_supported_by_CC_100() {
int deviceComputeCapability = cuda::sm_arch(cuda::current_device());
return deviceComputeCapability >= 100;
}
} // namespace transformer_engine } // namespace transformer_engine
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include <cudaTypedefs.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
...@@ -22,10 +23,29 @@ ...@@ -22,10 +23,29 @@
#include <vector> #include <vector>
#include "./nvtx.h" #include "./nvtx.h"
#include "./util/cuda_driver.h"
#include "./util/logging.h" #include "./util/logging.h"
namespace transformer_engine { namespace transformer_engine {
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
size_t ret = 1;
for (size_t i = begin; i < end; ++i) {
ret *= shape[i];
}
return ret;
}
inline size_t product(const std::vector<size_t> &shape) {
size_t ret = 1;
for (const auto &elem : shape) {
ret *= elem;
}
return ret;
}
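// Illustrative values for the helpers above:
//   product({2, 3, 4})       == 24
//   product({2, 3, 4}, 0, 2) == 6    // entries in [0, 2): 2 * 3
//   product({2, 3, 4}, 1, 3) == 12   // entries in [1, 3): 3 * 4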
struct SimpleTensor { struct SimpleTensor {
void *dptr; void *dptr;
std::vector<size_t> shape; std::vector<size_t> shape;
...@@ -33,20 +53,114 @@ struct SimpleTensor { ...@@ -33,20 +53,114 @@ struct SimpleTensor {
SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype) SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype)
: dptr(dptr), shape(shape), dtype(dtype) {} : dptr(dptr), shape(shape), dtype(dtype) {}
SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT
: dptr(tensor.data_ptr),
shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
dtype(static_cast<DType>(tensor.dtype)) {}
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
operator NVTEBasicTensor() const {
const NVTEShape shape = {this->shape.data(), this->shape.size()};
return {dptr, static_cast<NVTEDType>(dtype), shape};
}
int numel() const {
size_t acc = 1;
for (const auto &dim : shape) {
acc *= dim;
}
return acc;
}
}; };
struct Tensor { struct Tensor {
SimpleTensor data; SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor amax; SimpleTensor amax;
SimpleTensor scale; SimpleTensor scale;
SimpleTensor scale_inv; SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
NVTEScalingMode scaling_mode;
Tensor() Tensor()
: data(), : data(),
columnwise_data(),
amax(nullptr, {1}, DType::kFloat32), amax(nullptr, {1}, DType::kFloat32),
scale(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32) {} scale_inv(nullptr, {1}, DType::kFloat32),
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
int numel() const {
NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr,
"Tensor does not hold any data!");
size_t acc = 1;
if (data.dptr != nullptr) {
for (const auto &dim : data.shape) {
acc *= dim;
}
return acc;
}
// data is empty, use columnwise_data
for (const auto &dim : columnwise_data.shape) {
acc *= dim;
}
return acc;
}
bool has_data() const noexcept { return data.dptr != nullptr; }
bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; }
DType dtype() const {
if (has_data()) return data.dtype;
if (has_columnwise_data()) return columnwise_data.dtype;
// Fallback, used e.g. in workspace
return data.dtype;
}
/*! Matrix height after tensor is flattened to 2D
*
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_first_dim() const {
if (!has_data() && has_columnwise_data()) {
const auto &data_shape = columnwise_data.shape;
if (data_shape.empty()) return 1;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
return product(data_shape, 1, data_shape.size());
} else {
return product(data_shape, 0, data_shape.size() - 1);
}
}
const auto &data_shape = data.shape;
if (data_shape.empty()) return 1;
return product(data_shape, 0, data_shape.size() - 1);
}
/*! Matrix width after tensor is flattened to 2D
*
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_last_dim() const {
if (!has_data() && has_columnwise_data()) {
const auto &data_shape = columnwise_data.shape;
if (data_shape.empty()) return 1;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
return data_shape.front();
} else {
return data_shape.back();
}
}
const auto &data_shape = data.shape;
if (data_shape.empty()) return 1;
return data_shape.back();
}
}; };
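// Illustrative shapes (values are assumptions): a row-major tensor with
// data.shape = {8, 512, 1024} is viewed as a 2D matrix with
//   flat_first_dim() = 8 * 512 = 4096 rows and flat_last_dim() = 1024 columns.
// With delayed tensor scaling, columnwise_data holds the transpose, e.g.
// shape = {1024, 8, 512}, and the branches above recover the same logical dims:
//   product(shape, 1, 3) = 4096 and shape.front() = 1024.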
template <typename T> template <typename T>
...@@ -62,6 +176,10 @@ using fp16 = half; ...@@ -62,6 +176,10 @@ using fp16 = half;
using bf16 = nv_bfloat16; using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3; using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2; using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
using e8m0_t = uint8_t;
namespace detail { namespace detail {
...@@ -80,6 +198,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(half) ...@@ -80,6 +198,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16) TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME #undef TRANSFORMER_ENGINE_TYPE_NAME
} // namespace detail } // namespace detail
...@@ -150,6 +271,10 @@ struct TypeInfo { ...@@ -150,6 +271,10 @@ struct TypeInfo {
using type = fp8e5m2; \ using type = fp8e5m2; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} break; \ } break; \
case DType::kFloat8E8M0: { \
using type = byte; \
{ __VA_ARGS__ } \
} break; \
default: \ default: \
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
...@@ -181,6 +306,25 @@ struct TypeInfo { ...@@ -181,6 +306,25 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
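// Illustrative usage sketch (the kernel name is an assumption, not part of this change):
// dispatch a runtime dtype to a templated implementation for the non-FP8 types above.
//   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input.dtype(), T,
//       my_nonfp8_kernel<T><<<grid, block, 0, stream>>>(
//           static_cast<const T *>(input.data.dptr), input.numel()););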
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
switch (dtype) { \ switch (dtype) { \
using namespace transformer_engine; \ using namespace transformer_engine; \
...@@ -236,15 +380,22 @@ struct TypeInfo { ...@@ -236,15 +380,22 @@ struct TypeInfo {
NVTE_ERROR("Invalid type for 16 bit."); \ NVTE_ERROR("Invalid type for 16 bit."); \
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// #define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \
switch (SCALE_DIM) { \
inline size_t product(const std::vector<size_t> &shape) { case 1: { \
size_t ret = 1; constexpr size_t DIM = 1; \
for (const auto &elem : shape) { { __VA_ARGS__ } \
ret *= elem; } break; \
case 32: { \
constexpr size_t DIM = 32; \
{ __VA_ARGS__ } \
} break; \
default: { \
NVTE_ERROR("Invalid size of the MX scaling factor."); \
} \
} }
return ret;
} ////////////////////////////////////////////////////////////////////////////////////////////////////
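// Illustrative usage sketch for the MX scale-dim switch above (kernel and variable
// names are assumptions): turn runtime 1x32 / 32x1 block sizes into compile-time
// constants for a quantization kernel.
//   TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(scale_dim_y, SCALE_DIM_Y,
//       TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(scale_dim_x, SCALE_DIM_X,
//           mx_quantize_kernel<SCALE_DIM_Y, SCALE_DIM_X>
//               <<<grid, block, 0, stream>>>(input, output, scales);););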
inline int log2_ceil(int value) { inline int log2_ceil(int value) {
int log2_value = 0; int log2_value = 0;
...@@ -269,13 +420,37 @@ struct is_fp8<fp8e4m3> : std::true_type {}; ...@@ -269,13 +420,37 @@ struct is_fp8<fp8e4m3> : std::true_type {};
template <> template <>
struct is_fp8<fp8e5m2> : std::true_type {}; struct is_fp8<fp8e5m2> : std::true_type {};
// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
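// Illustrative sketch (tensor shape is an assumption): padding the rowwise scaling
// factors of a 1000 x 4096 MXFP8 tensor (1x32 blocks) to the alignments above.
//   auto round_up = [](size_t x, size_t align) { return (x + align - 1) / align * align; };
//   size_t scale_rows = round_up(1000, scale_tensor_alignment_Y_rowwise);       // -> 1024
//   size_t scale_cols = round_up(4096 / 32, scale_tensor_alignment_X_rowwise);  // -> 128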
size_t typeToSize(const DType type); size_t typeToSize(const DType type);
void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name); void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
bool is_fp8_dtype(const DType t); bool is_fp8_dtype(const DType t);
std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &type);
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_block_scaling(const NVTEScalingMode &mode) {
return mode != NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
return is_tensor_scaling(mode);
}
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
/*! \brief Update a tensor's FP8 scale-inverse /*! \brief Update a tensor's FP8 scale-inverse
* *
* The FP8 scale-inverse (dequantization scaling factor) is updated * The FP8 scale-inverse (dequantization scaling factor) is updated
...@@ -286,6 +461,20 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream); ...@@ -286,6 +461,20 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
#define NVTE_API_CALL(api_name) \ #define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);
void checkCuDriverContext(CUstream stream);
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
inline bool isPointerAligned(const void *const ptr, const int alignment);
// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size);
bool is_supported_by_CC_100();
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
...@@ -93,17 +93,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -93,17 +93,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const bool supported_ragged_offset_size = const bool supported_ragged_offset_size =
(!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500);
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) &&
(sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && // 8.9: t3hd, max_s=512, d=64, padding
(max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) && ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 &&
(head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv &&
((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) && max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 &&
(max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
(qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
// 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal}
(cudnn_runtime_version >= 90700 &&
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// sm90: fwd d<=256, bwd d=128 only
// sm100: fwd d<=128, bwd d<=128
((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) ||
(sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset) { !requires_64bit_ragged_offset) {
if (cudnn_runtime_version >= 8900) { if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8; backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
...@@ -135,8 +149,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -135,8 +149,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
!requires_64bit_ragged_offset) { !requires_64bit_ragged_offset) {
flag_m512 = true; flag_m512 = true;
} }
// TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging if (
if ( // architecture // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
// special conditions for blackwell
// TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
!(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
// architecture
((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
(cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
// sequence length // sequence length
...@@ -218,9 +236,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -218,9 +236,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(cudnn_runtime_version >= 90600 && (cudnn_runtime_version >= 90600 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
// TODO(cyang): fix bug for BRCM + cross-attention on sm100
(sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
cudnn_runtime_version <= 90700) ||
cudnn_runtime_version > 90700)))) ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
(sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
cudnn_runtime_version <= 90700) ||
cudnn_runtime_version > 90700))))) &&
max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
dropout == 0.0)))) && dropout == 0.0)))) &&
// check 64-bit ragged offset support // check 64-bit ragged offset support
......
...@@ -227,7 +227,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -227,7 +227,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_attn_scale(attn_scale); .set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) { if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_options.set_sliding_window_length(window_size_left + 1); sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
} }
sdpa_options.set_alibi_mask(is_alibi); sdpa_options.set_alibi_mask(is_alibi);
...@@ -457,8 +457,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -457,8 +457,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
bool is_dropout = (dropout_probability != 0.0f); bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion(); const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
// keep original batch size because cu_seqlens are created with [b+1] shape // keep original batch size because cu_seqlens are created with [b+1] shape
int64_t actual_b = b; int64_t actual_b = b;
if (is_ragged && cudnn_runtime_version >= 90600) { if (is_ragged && cudnn_runtime_version >= 90600) {
...@@ -667,7 +665,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -667,7 +665,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
} }
if (cudnn_runtime_version >= 90200 && window_size_left != -1) { if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_backward_options.set_sliding_window_length(window_size_left + 1); sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
} }
if (cudnn_runtime_version >= 90000) { if (cudnn_runtime_version >= 90000) {
......
...@@ -1670,8 +1670,6 @@ void fused_attn_fp8_fwd_impl_v1( ...@@ -1670,8 +1670,6 @@ void fused_attn_fp8_fwd_impl_v1(
auto bias_h = h; auto bias_h = h;
NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!"); NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!");
NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!");
try { try {
FADescriptor_v1 descriptor{b, FADescriptor_v1 descriptor{b,
...@@ -1798,36 +1796,33 @@ void fused_attn_fp8_fwd_impl_v1( ...@@ -1798,36 +1796,33 @@ void fused_attn_fp8_fwd_impl_v1(
// sdpa_options.set_bias(bias); // sdpa_options.set_bias(bias);
// } // }
// if (is_padding) { if (is_padding) {
// seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_q") .set_name("seq_q")
// .set_dim({b, 1, 1, 1}) .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32)); .set_data_type(fe::DataType_t::INT32));
// seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_kv") .set_name("seq_kv")
// .set_dim({b, 1, 1, 1}) .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32)); .set_data_type(fe::DataType_t::INT32));
// sdpa_options.set_padding_mask(is_padding) sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv);
// .set_seq_len_q(seq_q) }
// .set_seq_len_kv(seq_kv);
// }
// if (is_dropout) { if (is_dropout) {
// dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Seed") .set_name("Seed")
// .set_dim({1, 1, 1, 1}) .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64)); .set_data_type(fe::DataType_t::INT64));
// dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Offset") .set_name("Offset")
// .set_dim({1, 1, 1, 1}) .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64)); .set_data_type(fe::DataType_t::INT64));
// sdpa_options.set_dropout( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
// dropout_probability, dropout_seed, dropout_offset); }
// }
auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8(
Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options);
...@@ -1919,29 +1914,28 @@ void fused_attn_fp8_fwd_impl_v1( ...@@ -1919,29 +1914,28 @@ void fused_attn_fp8_fwd_impl_v1(
{amax_o, devPtrAmaxO}, {amax_o, devPtrAmaxO},
{Stats, devPtrM}}; {Stats, devPtrM}};
// if (is_bias) { /* if (is_bias) {
// variant_pack[bias] = devPtrBias; variant_pack[bias] = devPtrBias;
// } } */
// if (is_padding) { if (is_padding) {
// constexpr size_t nthreads_per_block = 128; constexpr size_t nthreads_per_block = 128;
// const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
// void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size; void* devActualSeqlenQ = static_cast<int8_t*>(workspace) + plan_workspace_size;
// void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) void* devActualSeqlenKV = static_cast<int8_t*>(devActualSeqlenQ) + b * sizeof(int32_t);
// + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
// cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>( b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ), // TODO(pass max_b)
// b, static_cast<const int32_t *>(devPtrCuSeqlensQ), static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
// static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t*>(devActualSeqlenKV));
// static_cast<int32_t *>(devActualSeqlenQ), variant_pack[seq_q] = devActualSeqlenQ;
// static_cast<int32_t *>(devActualSeqlenKV)); variant_pack[seq_kv] = devActualSeqlenKV;
// variant_pack[seq_q] = devActualSeqlenQ; }
// variant_pack[seq_kv] = devActualSeqlenKV;
// } if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
// if (is_dropout) { variant_pack[dropout_offset] = devPtrDropoutOffset;
// variant_pack[dropout_seed] = devPtrDropoutSeed; }
// variant_pack[dropout_offset] = devPtrDropoutOffset;
// }
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException& e) { } catch (cudnn_frontend::cudnnException& e) {
NVTE_ERROR(e.what()); NVTE_ERROR(e.what());
...@@ -1974,8 +1968,6 @@ void fused_attn_fp8_bwd_impl_v1( ...@@ -1974,8 +1968,6 @@ void fused_attn_fp8_bwd_impl_v1(
auto bias_h = h; auto bias_h = h;
NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!"); NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!");
NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!");
try { try {
FADescriptor_v1 descriptor{b, FADescriptor_v1 descriptor{b,
...@@ -2151,36 +2143,35 @@ void fused_attn_fp8_bwd_impl_v1( ...@@ -2151,36 +2143,35 @@ void fused_attn_fp8_bwd_impl_v1(
// } // }
// } // }
// if (is_padding) { if (is_padding) {
// seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_q") .set_name("seq_q")
// .set_dim({b, 1, 1, 1}) .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32)); .set_data_type(fe::DataType_t::INT32));
// seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_kv") .set_name("seq_kv")
// .set_dim({b, 1, 1, 1}) .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32)); .set_data_type(fe::DataType_t::INT32));
// sdpa_backward_options.set_padding_mask(is_padding) sdpa_backward_options.set_padding_mask(is_padding)
// .set_seq_len_q(seq_q) .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv); .set_seq_len_kv(seq_kv);
// } }
// if (is_dropout) { if (is_dropout) {
// dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Seed") .set_name("Seed")
// .set_dim({1, 1, 1, 1}) .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64)); .set_data_type(fe::DataType_t::INT64));
// dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Offset") .set_name("Offset")
// .set_dim({1, 1, 1, 1}) .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1}) .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64)); .set_data_type(fe::DataType_t::INT64));
// sdpa_backward_options.set_dropout( sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
// dropout_probability, dropout_seed, dropout_offset); }
// }
auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward(
q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s,
...@@ -2308,34 +2299,32 @@ void fused_attn_fp8_bwd_impl_v1( ...@@ -2308,34 +2299,32 @@ void fused_attn_fp8_bwd_impl_v1(
{amax_dP, devPtrAmaxdP}, {amax_dP, devPtrAmaxdP},
}; };
// if (is_bias) { /* if (is_bias) {
// variant_pack[bias] = devPtrBias; variant_pack[bias] = devPtrBias;
// if ((bias_b == 1) && (bias_h == h)) { if ((bias_b == 1) && (bias_h == h)) {
// variant_pack[dBias] = devPtrdBias; variant_pack[dBias] = devPtrdBias;
// } else { } else {
// variant_pack[dBias] = nullptr; variant_pack[dBias] = nullptr;
// } }
// } } */
// if (is_padding) { if (is_padding) {
// constexpr size_t nthreads_per_block = 128; constexpr size_t nthreads_per_block = 128;
// const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
// void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size; void* devActualSeqlenQ = static_cast<int8_t*>(workspace) + plan_workspace_size;
// void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) void* devActualSeqlenKV = static_cast<int8_t*>(devActualSeqlenQ) + b * sizeof(int32_t);
// + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
// cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>( b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ), // TODO(pass max_b)
// b, static_cast<const int32_t *>(devPtrCuSeqlensQ), static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
// static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t*>(devActualSeqlenKV));
// static_cast<int32_t *>(devActualSeqlenQ), variant_pack[seq_q] = devActualSeqlenQ;
// static_cast<int32_t *>(devActualSeqlenKV)); variant_pack[seq_kv] = devActualSeqlenKV;
// variant_pack[seq_q] = devActualSeqlenQ; }
// variant_pack[seq_kv] = devActualSeqlenKV;
// } if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
// if (is_dropout) { variant_pack[dropout_offset] = devPtrDropoutOffset;
// variant_pack[dropout_seed] = devPtrDropoutSeed; }
// variant_pack[dropout_offset] = devPtrDropoutOffset;
// }
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException& e) { } catch (cudnn_frontend::cudnnException& e) {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "../common.h" #include "../common.h"
#include "../util/logging.h" #include "../util/logging.h"
#include "common/util/cuda_runtime.h"
namespace { namespace {
...@@ -46,6 +47,95 @@ uint32_t _getAlignment(uintptr_t address) { ...@@ -46,6 +47,95 @@ uint32_t _getAlignment(uintptr_t address) {
} }
} }
struct GemmParam {
void *A;
void *B;
cublasOperation_t transA;
cublasOperation_t transB;
transformer_engine::DType Atype;
transformer_engine::DType Btype;
void *A_scale_inv;
void *B_scale_inv;
int lda;
int ldb;
GemmParam(cublasOperation_t transA, cublasOperation_t transB)
: A(nullptr),
B(nullptr),
transA(transA),
transB(transB),
Atype(transformer_engine::DType::kNumTypes),
Btype(transformer_engine::DType::kNumTypes),
A_scale_inv(nullptr),
B_scale_inv(nullptr),
lda(0),
ldb(0) {}
};
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
const transformer_engine::Tensor &B, const cublasOperation_t transB,
const int k, const int lda, const int ldb) {
using namespace transformer_engine;
NVTE_CHECK(A.scaling_mode == B.scaling_mode,
"Inputs A and B to GEMM need to have the same scaling mode!");
NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
GemmParam ret(transA, transB);
ret.lda = lda;
ret.ldb = ldb;
if (is_tensor_scaling(A.scaling_mode)) {
ret.A = A.data.dptr;
ret.A_scale_inv = A.scale_inv.dptr;
if (transA == CUBLAS_OP_T) {
ret.Atype = A.data.dtype;
} else {
ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
if (is_fp8_dtype(ret.Atype)) {
int arch = cuda::sm_arch(cuda::current_device());
if (arch < 100) {
// Hopper and Ada - we need to use columnwise_data and change transA
NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
ret.A = A.columnwise_data.dptr;
ret.transA = CUBLAS_OP_T;
ret.A_scale_inv = A.columnwise_scale_inv.dptr;
ret.lda = k;
}
}
}
ret.B = B.data.dptr;
ret.B_scale_inv = B.scale_inv.dptr;
if (transB == CUBLAS_OP_T) {
ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
if (is_fp8_dtype(ret.Btype)) {
int arch = cuda::sm_arch(cuda::current_device());
if (arch < 100) {
        // Hopper and Ada - we need to use columnwise_data and change transB
NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
ret.B = B.columnwise_data.dptr;
ret.transB = CUBLAS_OP_N;
ret.B_scale_inv = B.columnwise_scale_inv.dptr;
ret.ldb = k;
}
}
} else {
ret.Btype = B.data.dtype;
}
} else {
    // Block scaling: we need to pick the proper version of the data for each operand
    // (note that the tensor-scaling branch above also covers high-precision types).
    // The transA/transB values are left as is, since Blackwell supports both layouts.
ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
}
return ret;
}
} // namespace } // namespace
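A note on the Hopper/Ada branches above: cuBLASLt accepts FP8 GEMMs only in TN layout (A transposed, B non-transposed) on SM 8.9/9.0, so an FP8 operand requested in the other orientation has to be served from its pre-transposed columnwise copy, with the corresponding operation flag flipped. A minimal sketch of that rule, assuming the same cuBLAS headers as this file; needs_columnwise_fallback is a hypothetical helper introduced here only for illustration:

// Illustration only (not part of the library): the layout rule applied above.
// On SM < 100, cuBLASLt runs FP8 GEMMs in TN form only, so an FP8 operand requested
// in the other orientation must come from its pre-transposed columnwise copy.
bool needs_columnwise_fallback(bool is_operand_a, cublasOperation_t op, bool is_fp8,
                               int sm_arch) {
  if (!is_fp8 || sm_arch >= 100) return false;  // Blackwell accepts both layouts
  // A must end up transposed (OP_T), B must end up non-transposed (OP_N).
  return is_operand_a ? (op == CUBLAS_OP_N) : (op == CUBLAS_OP_T);
}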
namespace transformer_engine { namespace transformer_engine {
...@@ -56,10 +146,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -56,10 +146,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
int math_sm_count, int m_split, int n_split, bool gemm_producer, int math_sm_count, int m_split, int n_split, bool gemm_producer,
const Tensor *inputCounter, cudaStream_t stream) { const Tensor *inputCounter, cudaStream_t stream) {
void *A = inputA->data.dptr; // Return immediately if GEMM is trivial
void *A_scale_inverse = inputA->scale_inv.dptr; if (m <= 0 || n <= 0) {
void *B = inputB->data.dptr; return;
void *B_scale_inverse = inputB->scale_inv.dptr; }
NVTE_CHECK(k > 0);
const GemmParam &param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb);
void *C = outputD->data.dptr; void *C = outputD->data.dptr;
void *D = outputD->data.dptr; void *D = outputD->data.dptr;
void *D_scale = outputD->scale.dptr; void *D_scale = outputD->scale.dptr;
...@@ -72,15 +165,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -72,15 +165,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
counter = inputCounter->data.dptr; counter = inputCounter->data.dptr;
} }
const bool gelu = pre_gelu_out != nullptr; const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype); const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype);
const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype); const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype); const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype); const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);
NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr, NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!"); "FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr, NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!"); "FP8 input to GEMM requires inverse of scale!");
// check consistency of arguments: // check consistency of arguments:
...@@ -117,17 +211,17 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -117,17 +211,17 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
} }
// Create matrix descriptors. Not setting any extra attributes. // Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, transa == CUBLAS_OP_N ? m : k, NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
transa == CUBLAS_OP_N ? k : m, lda)); param.transA == CUBLAS_OP_N ? k : m, param.lda));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, transb == CUBLAS_OP_N ? k : n, NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
transb == CUBLAS_OP_N ? n : k, ldb)); param.transB == CUBLAS_OP_N ? n : k, param.ldb));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa))); &param.transA, sizeof(param.transA)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
&transb, sizeof(transb))); &param.transB, sizeof(param.transB)));
// Set math SM count // Set math SM count
if (math_sm_count != 0) { if (math_sm_count != 0) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
...@@ -143,12 +237,53 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -143,12 +237,53 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1; const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
&fastAccuMode, sizeof(fastAccuMode))); &fastAccuMode, sizeof(fastAccuMode)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, // Scaling factors.
&A_scale_inverse, sizeof(A_scale_inverse))); #if CUDA_VERSION >= 12080
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, cublasLtMatmulMatrixScale_t scaling_mode;
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, #endif
&B_scale_inverse, sizeof(B_scale_inverse))); if ((is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode))) {
void *A_scale_inverse = param.A_scale_inv;
void *B_scale_inverse = param.B_scale_inv;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080
scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
} else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) {
fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublasLtGetVersion() <= 120803) {
const int64_t dummy_a_vec_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
sizeof(dummy_a_vec_stride)));
}
#endif
} else {
NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " +
to_string(inputB->scaling_mode) + ".");
}
#if CUDA_VERSION >= 12080
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
#endif
if (is_fp8_dtype(outputD->data.dtype)) { if (is_fp8_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8 output // Accumulation mode not supported for FP8 output
C = nullptr; C = nullptr;
...@@ -156,8 +291,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -156,8 +291,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale))); operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
// For FP8 output, cuBLAS requires C_type to be same as bias_type #if CUDA_VERSION >= 12080
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
#endif
// For FP8 output, cuBLAS requires C_type to match bias_type and
// be FP16/BF16
const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
} else { } else {
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
} }
...@@ -235,8 +376,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -235,8 +376,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(A)); const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(B)); const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C)); const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D)); const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
...@@ -260,8 +401,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -260,8 +401,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// D = alpha * (A * B) + beta * C // D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
static_cast<const void *>(&one), /* alpha */ static_cast<const void *>(&one), /* alpha */
A, /* A */ param.A, /* A */
Adesc, B, /* B */ Adesc, param.B, /* B */
Bdesc, static_cast<const void *>(&beta), /* beta */ Bdesc, static_cast<const void *>(&beta), /* beta */
C, /* C */ C, /* C */
Cdesc, D, /* D */ Cdesc, D, /* D */
...@@ -270,7 +411,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ...@@ -270,7 +411,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
workspaceSize, stream)); /* stream */ workspaceSize, stream)); /* stream */
// Update FP8 scale-inv in output tensor // Update FP8 scale-inv in output tensor
if (is_fp8_dtype(outputD->data.dtype)) { // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
  // TODO: Change the gemm interface so that D->scale_inv is allocated and the scale_inv can be
  // calculated here.
if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
update_tensor_scale_inv(outputD, stream); update_tensor_scale_inv(outputD, stream);
} }
...@@ -309,9 +453,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons ...@@ -309,9 +453,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out); Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace); Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; const size_t A0 = inputA->flat_first_dim();
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; const size_t A1 = inputA->flat_last_dim();
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; const size_t B0 = inputB->flat_first_dim();
const size_t B1 = inputB->flat_last_dim();
const int m = transa ? A0 : A1;
const int k = transa ? A1 : A0;
const int n = transb ? B1 : B0;
int lda, ldb, ldd; int lda, ldb, ldd;
if (transa && !transb) { // TN if (transa && !transb) { // TN
lda = k; lda = k;
...@@ -357,6 +506,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor ...@@ -357,6 +506,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter); const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace); Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode),
"Atomic GEMM only supports delayed scaling.");
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
......
...@@ -19,7 +19,9 @@ extern "C" { ...@@ -19,7 +19,9 @@ extern "C" {
/* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ /* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */
/*! \brief Compute activation of the input. /*! \brief Computes activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
* *
* \param[in] input Input tensor for activation. * \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor. * \param[in,out] output Output tensor.
...@@ -39,17 +41,59 @@ enum class NVTE_Activation_Type { ...@@ -39,17 +41,59 @@ enum class NVTE_Activation_Type {
SREGLU, SREGLU,
}; };
/*! \brief Computes the GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
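For reference, the elementwise formulas behind these entry points, written as a small host-side sketch (assumes <cmath>; the actual kernels may use tanh-based GeLU approximations and fuse the FP8/MXFP8 scaling, so this is reference math only, not the device code):

// Reference math only; hypothetical helpers, not part of the TE API.
float ref_gelu(float x)  { return 0.5f * x * (1.0f + erff(x * 0.70710678f)); }  // exact (erf) GeLU
float ref_silu(float x)  { return x / (1.0f + expf(-x)); }                      // x * sigmoid(x)
float ref_relu(float x)  { return x > 0.0f ? x : 0.0f; }
float ref_qgelu(float x) { return x / (1.0f + expf(-1.702f * x)); }             // x * sigmoid(1.702 x)
float ref_srelu(float x) { const float r = ref_relu(x); return r * r; }         // relu(x)^2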
/*! \brief Compute activation gradient. /*! \brief Computes the GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
* *
* \param[in] grad Incoming gradient. * \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation. * \param[in] input Input tensor for activation.
...@@ -59,19 +103,57 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); ...@@ -59,19 +103,57 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Compute gated activation of the input. /*! \brief Computes the gated GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
* *
* \param[in] input Input tensor of shape [N, H * 2]. * \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H]. * \param[in,out] output Output tensor of shape [N, H].
...@@ -80,15 +162,54 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu ...@@ -80,15 +162,54 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
*/ */
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated Swish activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
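To make the gated convention "Act(input[N, :H]) x input[N, H:]" concrete, a minimal sketch of one output element of the SwiGLU variant, assuming a row-major [N, H * 2] input; the other *GLU entry points only swap the activation applied to the first half. ref_swiglu_element is a hypothetical reference helper, not TE API:

// Sketch of the semantics of output[n][h] for nvte_swiglu; input_row points at row n, 0 <= h < H.
float ref_swiglu_element(const float* input_row, int H, int h) {
  const float gate   = input_row[h];      // first half: activated
  const float linear = input_row[H + h];  // second half: passed through
  const float silu   = gate / (1.0f + expf(-gate));
  return silu * linear;                   // Act(input[N, :H]) x input[N, H:]
}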
/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated Quick GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated Squared ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Compute gated activation gradient. /*! \brief Computes the gated GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H]. * \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2]. * \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2]. * \param[in,out] output Outgoing gradient of shape [N, H * 2].
...@@ -97,15 +218,51 @@ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) ...@@ -97,15 +218,51 @@ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Computes the gated Swish activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Computes the gated Quick GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Computes the gated Squared ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
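As a sanity reference for the gated backward convention above, a small sketch of how the [N, H] incoming gradient maps onto the [N, H * 2] outgoing gradient for the SwiGLU case. This is the standard chain rule, not the fused kernel; ref_dswiglu_element is a hypothetical helper for illustration:

// grad_row has H elements; input_row and out_row have 2 * H elements (row n of each tensor).
// d/d(gate)   [silu(gate) * linear] = linear * s * (1 + gate * (1 - s)),  s = sigmoid(gate)
// d/d(linear) [silu(gate) * linear] = silu(gate)
void ref_dswiglu_element(const float* grad_row, const float* input_row, float* out_row,
                         int H, int h) {
  const float g      = grad_row[h];
  const float gate   = input_row[h];
  const float linear = input_row[H + h];
  const float s      = 1.0f / (1.0f + expf(-gate));
  out_row[h]     = g * linear * s * (1.0f + gate * (1.0f - s));  // gradient w.r.t. the gate half
  out_row[H + h] = g * gate * s;                                 // gradient w.r.t. the linear half
}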
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
************************************************************************/ ************************************************************************/
/*! \file cast.h /*! \file cast.h
* \brief Functions to cast to/from FP8. * \brief Functions to cast to/from FP8/MXFP8.
*/ */
#ifndef TRANSFORMER_ENGINE_CAST_H_ #ifndef TRANSFORMER_ENGINE_CAST_H_
...@@ -17,21 +17,200 @@ ...@@ -17,21 +17,200 @@
extern "C" { extern "C" {
#endif #endif
/*! \brief Cast tensor to FP8. /* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer)
* The implementation is per the microscaling format MXFP8 defined by the OCP specification:
* https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
* *
* \param[in] input Input tensor to be cast. * Supported modes of scaling (live scaling):
* \param[in,out] output Output FP8 tensor. * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes:
* \param[in] stream CUDA stream used for the operation. * - the scaled output tensor
* - the corresponding scaling factors
* The scaling factors are computed for blocks of the shape [1,32]
* (i.e., each scaling factor spans 32 contiguous elements along rows).
*
 * 2) Columnwise scaling (along the dim=1) computes one set of the output data.
* The scaling factors are computed for blocks of the shape [32,1]
* (i.e., each scaling factor spans 32 contiguous elements along columns).
*
* 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1)
* computes two sets of the output data: both 1) and 2).
*
* The shape of the MX block must be specified in the 'output' argument,
* and can be either [1,32] or [32,1] as no other shapes are currently supported.
*
* To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter
* of the output tensor should be set to 0.
*/
/*! \brief Casts input tensor to FP8/MXFP8.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
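To illustrate the block quantization the MXFP8 path performs, here is a conceptual host-side sketch of how the shared power-of-two scale of a single [1,32] block can be derived, following the OCP MX rule scale_exp = floor(log2(amax)) - emax(element type), with emax = 8 for FP8 E4M3. This is reference arithmetic under those assumptions, not the TE kernel, and mx_block_scale_exponent is a hypothetical helper:

#include <cmath>

// Conceptual only: one shared power-of-two scale per 32-element block, chosen so that
// the block's absolute maximum, divided by the scale, fits in the E4M3 range.
int mx_block_scale_exponent(const float* block, int block_size = 32) {
  float amax = 0.0f;
  for (int i = 0; i < block_size; ++i) amax = std::fmax(amax, std::fabs(block[i]));
  if (amax == 0.0f) return 0;  // all-zero block: any scale works
  return static_cast<int>(std::floor(std::log2(amax))) - 8;  // 8 = emax of FP8 E4M3
}
// Each element is then stored as cast_to_e4m3(block[i] * exp2f(-scale_exponent)),
// with the exponent kept as the block's E8M0 scaling factor.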
/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] noop Noop tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream);
/*! \brief Casts input tensor to FP8/MXFP8. Additionally, reduces the input along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
 *  - `output` is equal to `cast(input)`
 *  - `dbias` is equal to `reduce(input, dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream);
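The workspace-query convention described above is the usual two-call pattern. A hedged usage sketch follows; make_empty_tensor and allocate_like are hypothetical stand-ins for the caller's own tensor management, not TE API, and input/output/dbias/stream are assumed to be set up elsewhere:

// 1) Query: pass an empty workspace; the call performs no computation and only
//    fills in the required workspace shape and dtype.
NVTETensor workspace = make_empty_tensor();           // hypothetical helper
nvte_quantize_dbias(input, output, dbias, workspace, stream);

// 2) Allocate a buffer matching the reported shape/dtype, then run for real.
NVTETensor workspace_alloc = allocate_like(workspace); // hypothetical helper
nvte_quantize_dbias(input, output, dbias, workspace_alloc, stream);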
/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Cast tensor from FP8. /*! \brief Casts input tensor from reduced to higher precision.
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,
* the block dequantization (MXFP8) of the specified shape of the block will be used.
* In case of the MXFP8 dequantization, the dequantized values are stored to the rowwise
* data of the output tensor, regardless of whether the row- or columnwise scaling is used.
* *
* \param[in] input Input tensor to be cast. * \param[in] input Input FP8/MXFP8 tensor to be cast.
* \param[out] output Output tensor. * \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
......
...@@ -17,11 +17,26 @@ ...@@ -17,11 +17,26 @@
extern "C" { extern "C" {
#endif #endif
/*! \brief Transposes the input, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
*
* \param[in] input Input tensor.
* \param[in] noop Noop tensor.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, /*! \brief Casts and transposes the input, providing the option to immediately exit the kernel
NVTETensor cast_output, NVTETensor transposed_output, * based on the value of the 'noop' tensor.
*
* \param[in] input Input tensor.
* \param[in] noop Noop tensor.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -53,6 +53,8 @@ class CommOverlapCore { ...@@ -53,6 +53,8 @@ class CommOverlapCore {
int _cga_size; int _cga_size;
int _use_ce; int _use_ce;
int _ub_reg; int _ub_reg;
int _gemm_priority;
int _comm_priority;
bool _atomic_gemm{false}; bool _atomic_gemm{false};
bool _is_p2p{false}; bool _is_p2p{false};
...@@ -65,10 +67,13 @@ class CommOverlapCore { ...@@ -65,10 +67,13 @@ class CommOverlapCore {
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
public: public:
CommOverlapCore() {} // dummy constructor for exposing type to Python
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority,
bool set_sm_margin, bool use_ce, bool atomic_gemm); int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm);
virtual ~CommOverlapCore(); virtual ~CommOverlapCore();
...@@ -77,25 +82,76 @@ class CommOverlapCore { ...@@ -77,25 +82,76 @@ class CommOverlapCore {
_ubuf_scale_inv_initialized = true; _ubuf_scale_inv_initialized = true;
} }
TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
bool is_atomic_gemm() { return _atomic_gemm; } bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; } bool is_p2p_overlap() { return _is_p2p; }
bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }
virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, CommOverlapType comm_type,
TensorWrapper &rs_output, cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
}; // CommOverlapCore }; // CommOverlapCore
class CommOverlapBase : public CommOverlapCore { class CommOverlapBase : public CommOverlapCore {
protected: protected:
int _rs_kernel_type; int _rs_kernel_type;
bool _rs_overlap_first_gemm;
cudaStream_t _stream_comm; cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy; cudaEvent_t _start_d2dcopy;
public: public:
CommOverlapBase() {} // dummy constructor for exposing type to Python
CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank, CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
bool set_sm_margin = true, bool atomic_gemm = false,
bool rs_overlap_first_gemm = false);
virtual ~CommOverlapBase(); virtual ~CommOverlapBase();
...@@ -103,97 +159,124 @@ class CommOverlapBase : public CommOverlapCore { ...@@ -103,97 +159,124 @@ class CommOverlapBase : public CommOverlapCore {
** Bulk GEMM + COMM ** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf ** This function assumes the communication input is pre-copied to _ubuf
*/ */
void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &workspace, bool grad, bool accumulate,
CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main); bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output,
cudaStream_t stream_main) override;
void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
}
void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
}
/* /*
** Split FPROP GEMM + ReduceScatter ** Split FPROP GEMM + ReduceScatter
*/ */
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &workspace, bool grad, bool accumulate, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool use_split_accumulator, bool gemm_overlap, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
TensorWrapper &rs_output, cudaStream_t stream_main); cudaStream_t stream_main) override;
/* /*
** Split FPROP GEMM + ReduceScatter ** Split FPROP GEMM + ReduceScatter
*/ */
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate, TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main); cudaStream_t stream_main) override;
}; // CommOverlapBase }; // CommOverlapBase
class CommOverlapP2PBase : public CommOverlapCore { class CommOverlapP2PBase : public CommOverlapCore {
protected: protected:
bool _is_reduce_scatter{false}; bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false}; bool _use_multiatomic_ag{false};
bool _aggregate;
int _next_rank; int _next_rank;
int _prev_rank; int _prev_rank;
int _rank_round_tp; int _rank_round_tp;
int _aggregate;
int _num_ubuf_chunks; int _num_ubuf_chunks;
int _self_chunk_id; int _self_chunk_id;
std::vector<TensorWrapper> _ubufs; std::vector<TensorWrapper> _ubufs;
std::vector<cudaStream_t> _stream_send;
cudaStream_t _stream_send;
cudaStream_t _stream_recv; cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv; cudaEvent_t _stop_send, _stop_recv;
public: public:
CommOverlapP2PBase() {} // dummy constructor for exposing type to Python
CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank, CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS,
int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false, int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0,
bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true,
bool atomic_gemm = false, bool aggregate = false);
virtual ~CommOverlapP2PBase(); virtual ~CommOverlapP2PBase();
TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id);
void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
}
/* /*
** Split AllGather + AtomicGEMM using P2P communication ** Split AllGather + AtomicGEMM using P2P communication
** This function assumes that input_b has been pre-copied to _ubufs[rank_id]. This is needed so that ** This function assumes that input_b has been pre-copied to _ubufs[rank_id]. This is needed so that
** the AG outputs on each rank end up in contiguous memory after all ring exchange phases. ** the AG outputs on each rank end up in contiguous memory after all ring exchange phases.
*/ */
void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &workspace, bool grad, bool accumulate, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool use_split_accumulator, TensorWrapper &B_copy, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main); cudaStream_t stream_main) override;
/* /*
** Split AllGather + GEMM using P2P communication ** Split AllGather + GEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
** outputs in each rank to be in the contiguous memory space after all ring exchange phases. ** outputs in each rank to be in the contiguous memory space after all ring exchange phases.
*/ */
void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate, TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy, bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main); cudaStream_t stream_main) override;
/* /*
** Split ReduceScatter + GEMM using P2P communication ** Split ReduceScatter + GEMM using P2P communication
*/ */
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &workspace, bool grad, bool accumulate, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool use_split_accumulator, TensorWrapper &rs_output, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main); cudaStream_t stream_main) override;
/* /*
** Split ReduceScatter + GEMM using P2P communication ** Split ReduceScatter + GEMM using P2P communication
*/ */
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate, TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main); cudaStream_t stream_main) override;
}; // CommOverlapP2PBase }; // CommOverlapP2PBase
} // namespace transformer_engine } // namespace transformer_engine
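As a complement to the declarations above, the following hypothetical sketch (not from the commit) shows the shape of a call into CommOverlapP2PBase::split_overlap_ag. Per the comments above, the caller must stage B into the overlap object's internal buffer chunk for its rank before the call; the staging step is left as a comment because its exact API is not shown in this diff, and the function name and include path are assumptions.

#include <cuda_runtime.h>
#include <transformer_engine/comm_gemm_overlap.h>  // include path assumed

using transformer_engine::CommOverlapP2PBase;
using transformer_engine::TensorWrapper;

void run_ag_gemm(CommOverlapP2PBase &overlap, const TensorWrapper &A, const TensorWrapper &B,
                 TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                 TensorWrapper &workspace, TensorWrapper &B_copy, cudaStream_t stream) {
  // Precondition (see the comments above): B has already been copied into the
  // communication buffer chunk owned by this rank (_ubufs[rank_id]).
  overlap.split_overlap_ag(A, /*transa=*/true, B, /*transb=*/false, D, bias, pre_gelu_out,
                           workspace, /*grad=*/false, /*accumulate=*/false,
                           /*use_split_accumulator=*/false, B_copy, stream);
}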
......
...@@ -28,16 +28,10 @@ extern "C" { ...@@ -28,16 +28,10 @@ extern "C" {
* \param[in] amax_history History of maximum absolute values. * \param[in] amax_history History of maximum absolute values.
* Shape: [history_length, num_scales] * Shape: [history_length, num_scales]
* \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales] * \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales]
* \param[in] scale_inv Scaling factor for casting from FP8. Shape: [num_scales]
* \param[in] scale_inv_mask Boolean mask indicating scale_inv entries to update. May be
* empty, in which case all scale_inv entries are updated.
* Shape: [num_scales]
* \param[out] updated_amax_history Updated history of maximum absolute values. * \param[out] updated_amax_history Updated history of maximum absolute values.
* Shape: [history_length, num_scales] * Shape: [history_length, num_scales]
* \param[out] updated_scale Updated scaling factor for casting to FP8. * \param[out] updated_scale Updated scaling factor for casting to FP8.
* Shape: [num_scales] * Shape: [num_scales]
* \param[out] updated_scale_inv Updated scaling factor for casting from FP8.
* Shape: [num_scales]
* \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and
* "most_recent". * "most_recent".
* \param[in] fp8_dtype FP8 datatype. * \param[in] fp8_dtype FP8 datatype.
...@@ -45,9 +39,8 @@ extern "C" { ...@@ -45,9 +39,8 @@ extern "C" {
* \param[in] stream CUDA stream. * \param[in] stream CUDA stream.
*/ */
void nvte_delayed_scaling_recipe_amax_and_scale_update( void nvte_delayed_scaling_recipe_amax_and_scale_update(
const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history,
const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
cudaStream_t stream); cudaStream_t stream);
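For reference, a minimal usage sketch of the updated, scale_inv-free signature above. It assumes the four NVTETensor handles already exist with the documented shapes, that the declaration lives in transformer_engine/recipe.h, and that kNVTEFloat8E4M3 is the desired FP8 type; this is illustrative, not code from the commit.

#include <cuda_runtime.h>
#include <transformer_engine/recipe.h>  // header path assumed

void update_fp8_scales(NVTETensor amax_history, NVTETensor scale,
                       NVTETensor updated_amax_history, NVTETensor updated_scale,
                       cudaStream_t stream) {
  // Reduce the amax history with "max" and recompute scales for E4M3 with zero margin.
  nvte_delayed_scaling_recipe_amax_and_scale_update(
      amax_history, scale, updated_amax_history, updated_scale,
      /*amax_compute_algo=*/"max", /*fp8_dtype=*/kNVTEFloat8E4M3, /*margin=*/0.0f, stream);
}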
/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. /*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction.
...@@ -55,7 +48,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( ...@@ -55,7 +48,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
 * Operations performed include updating the most recent amax history * Operations performed include updating the most recent amax history
* with the relevant segment of global reduction buffer if it's not 0, * with the relevant segment of global reduction buffer if it's not 0,
* rotating the amax history based on the rule below, and updating the * rotating the amax history based on the rule below, and updating the
* scales and scale_invs. * scales.
* *
* The amax history is rotated by -1 (e.g. the first entry shifts to * The amax history is rotated by -1 (e.g. the first entry shifts to
* the last, the last entry shifts to the second to last) and the * the last, the last entry shifts to the second to last) and the
...@@ -69,8 +62,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( ...@@ -69,8 +62,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
* Shape: num_tensors x [history_length, num_scales] * Shape: num_tensors x [history_length, num_scales]
* \param[in,out] scales List of scaling factors for casting to FP8. * \param[in,out] scales List of scaling factors for casting to FP8.
* Shape: num_tensors x [num_scales] * Shape: num_tensors x [num_scales]
* \param[in,out] scale_invs List of scaling factors for casting from FP8.
* Shape: num_tensors x [num_scales]
* \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and
* "most_recent". * "most_recent".
* \param[in] fp8_dtype FP8 datatype. * \param[in] fp8_dtype FP8 datatype.
...@@ -79,8 +70,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( ...@@ -79,8 +70,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
*/ */
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories, const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs, std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); float margin, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file swizzle.h
 *  \brief Functions to swizzle FP8 scaling factors into the layout required by GEMM.
 */
#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_
#define TRANSFORMER_ENGINE_SWIZZLE_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM
*
* \param[in] input Input tensor with non-swizzled scale_inv.
* \param[in,out] output Output tensor which hosts swizzled scale_inv.
* \param[in] stream CUDA stream used for the operation.
*
* Requirements:
* - scale_inv is stored in row-major.
* - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
 *  - data is quantized along the K-dimension, i.e. the 1D scaling block lies along the K-dimension.
*/
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_SWIZZLE_H_
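To illustrate the padding requirement stated in the new header, the sketch below computes a padded scale_inv shape, reading "128x4 for row-scale and 4x128 for col-scale" as multiples of 128 rows x 4 columns and 4 rows x 128 columns respectively. That reading, the helper names, and the function itself are assumptions for illustration only, not part of the header.

#include <cstddef>

// Round x up to the nearest multiple of `multiple`.
static inline size_t round_up(size_t x, size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

// Compute the padded (rows, cols) of a row-major scale_inv buffer before swizzling.
void padded_scale_inv_shape(size_t rows, size_t cols, bool rowwise,
                            size_t *padded_rows, size_t *padded_cols) {
  if (rowwise) {
    *padded_rows = round_up(rows, 128);  // 128x4 tiles for row-scale
    *padded_cols = round_up(cols, 4);
  } else {
    *padded_rows = round_up(rows, 4);    // 4x128 tiles for col-scale
    *padded_cols = round_up(cols, 128);
  }
}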