Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
......@@ -19,19 +19,9 @@ try:
except (ImportError, StopIteration) as e:
pass
try:
from . import paddle
except (ImportError, StopIteration) as e:
pass
try:
import transformer_engine_jax
except ImportError:
pass
try:
import transformer_engine_paddle
except ImportError:
pass
__version__ = str(metadata.version("transformer_engine"))
......@@ -6,13 +6,17 @@ cmake_minimum_required(VERSION 3.21)
# Language options
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
else ()
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G")
endif()
# Hide non-necessary symbols in shared object.
......@@ -78,6 +82,7 @@ list(APPEND transformer_engine_SOURCES
util/cuda_runtime.cpp
util/rtc.cpp
util/system.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
......
......@@ -4,111 +4,71 @@
* See LICENSE for license information.
************************************************************************/
/*! \file activation_template.h
* \brief Activation functions template.
*/
#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include "../common.h"
#include "../util/cast_gated_kernels.cuh"
#include "../util/cast_kernels.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "act_lu_input");
CheckOutputTensor(*output, "act_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape);
void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = true;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
workspace, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "dact_lu_input");
CheckInputTensor(grad, "dact_lu_input_grad");
CheckOutputTensor(*output, "dact_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape);
void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, Param, OP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), tot_elts, {},
stream);); // NOLINT(*)
); // NOLINT(*)
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
workspace, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
"Input shape[0] must be equal to output shape[0].");
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1].");
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, ComputeType, Param, OP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const ComputeType *>(output->scale.dptr),
reinterpret_cast<ComputeType *>(output->amax.dptr),
reinterpret_cast<ComputeType *>(output->scale_inv.dptr), output->data.shape[0],
output->data.shape[1], {},
stream);); // NOLINT(*)
); // NOLINT(*)
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP1)(ComputeType, const Param &),
ComputeType (*OP2)(ComputeType, const Param &)>
void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output");
NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
"Output shape[0] must be equal to grad shape[0].");
NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
"Output shape[1] must be 2x larger than grad shape[1].");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, ComputeType, Param, OP1, OP2>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr), grad.data.shape[0], grad.data.shape[1],
{},
stream);); // NOLINT(*)
); // NOLINT(*)
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_
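The refactor above routes every activation through quantize_helper / quantize_gated_helper instead of per-function dispatch. As a hedged illustration (not part of this commit), a new element-wise activation would only need a per-element functor pair plus thin C wrappers in the same style as the nvte_gelu family below; the functor names tanh_act/dtanh_act and the wrapper names are hypothetical placeholders, not real TE symbols.

void nvte_tanh_example(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_tanh_example);
  using namespace transformer_engine;
  // act_fn forwards to quantize_helper, which handles dtype dispatch, amax/scale updates
  // and optional fused quantization of the output tensor.
  act_fn<fp32, Empty, tanh_act<fp32, fp32>>(input, output, stream);
}
void nvte_dtanh_example(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_dtanh_example);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dtanh_act<fp32, fp32>>(grad, input, output, stream);
}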
......@@ -3,69 +3,58 @@
*
* See LICENSE for license information.
************************************************************************/
#include "../util/math.h"
#include "./activation_template.h"
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine;
act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgelu);
using namespace transformer_engine;
act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, stream);
}
......@@ -10,63 +10,51 @@
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_relu);
using namespace transformer_engine;
act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, drelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_srelu);
using namespace transformer_engine;
act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, stream);
}
......@@ -10,31 +10,25 @@
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_silu);
using namespace transformer_engine;
act_fn<fp32, Empty, silu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsilu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, silu<fp32, fp32>>(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(
*reinterpret_cast<const Tensor*>(grad), *reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output), stream);
dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
}
......@@ -21,6 +21,8 @@
#define HALF_BYTES 2
#define UB_MAX_SM 32
#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)
using namespace std::placeholders;
namespace transformer_engine {
......@@ -40,8 +42,9 @@ bool ubuf_built_with_mpi() {
CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm) {
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm) {
// Initialize userbuf communicator
if (!_comm_created) {
if (myrank == 0) {
......@@ -59,9 +62,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
if (gemm_priority == 0 && comm_priority == 0) {
transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority);
} else {
_gemm_priority = gemm_priority;
_comm_priority = comm_priority;
}
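// Editorial sketch, not part of the commit: when both priorities are passed as 0, the helper
// above picks defaults for the GEMM and communication streams. A hand-rolled equivalent with
// the plain CUDA runtime could look like the lines below; whether stream_priority_range is
// implemented this way, and which end of the range it assigns to communication, is an
// assumption.
//   int least_priority = 0, greatest_priority = 0;
//   NVTE_CHECK_CUDA(cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
//   // Numerically lower values mean higher scheduling priority.
//   int auto_comm_priority = greatest_priority;  // communication preempts compute
//   int auto_gemm_priority = least_priority;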
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority));
_stream_compute.push_back(std::move(stream));
}
......@@ -130,6 +139,73 @@ CommOverlapCore::~CommOverlapCore() {
}
}
TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
const std::vector<size_t> &chunk_shape) {
TensorWrapper chunk;
for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
auto param_type = static_cast<NVTETensorParam>(param_id);
auto param = source.get_parameter(param_type);
auto param_dptr = reinterpret_cast<char *>(param.data_ptr);
auto param_dtype = static_cast<DType>(param.dtype);
auto param_shape = AS_VECTOR(param.shape);
if (param_dptr != nullptr) {
if (param_type == NVTETensorParam::kNVTERowwiseData ||
param_type == NVTETensorParam::kNVTEColumnwiseData) {
// Offset data pointer
param_dptr += chunk_offset * typeToSize(param_dtype);
param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// Columnwise shape for non-block scaled tensors shifts the last dimension to the front
auto last_dim = param_shape.back();
param_shape.pop_back();
param_shape.insert(param_shape.begin(), last_dim);
}
} else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
(param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
// Calculate block scaling offset and size
auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? source.shape().data[0]
: source.columnwise_shape().data[0];
auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? chunk_shape.front()
: chunk_shape.back();
auto chunk_scale_start = chunk_offset / 32;
auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
param_dptr += chunk_scale_start * typeToSize(param_dtype);
param_shape = std::vector<size_t>{chunk_scale_size};
}
// Set chunked source parameters into the chunked tensor output
chunk.set_parameter(param_type, reinterpret_cast<void *>(param_dptr), param_dtype,
param_shape);
}
}
return chunk;
}
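// Worked example (editorial, not part of the commit): taking a {512, 1024} chunk out of a
// larger 2D source advances the row-wise data pointer by chunk_offset elements and keeps the
// shape {512, 1024}; the column-wise data of a non-MXFP8 tensor is reported as {1024, 512},
// i.e. the last dimension rotated to the front. For MXFP8 tensors the row-wise/column-wise
// scale-inverse parameters are sliced instead, at a granularity of one scale per 32 elements
// along the scaled dimension.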
TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source,
size_t chunk_offset,
const std::vector<size_t> &chunk_shape) {
// Start with a chunk of the source tensor
auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape);
// Update chunk with offset data pointers from the communication buffer
auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size());
if (chunk.dptr() != nullptr) {
chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape());
}
if (chunk.columnwise_dptr() != nullptr) {
chunk.set_columnwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(),
chunk.columnwise_shape());
}
return chunk;
}
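// Editorial note (not part of the commit): get_tensor_chunk() returns a view whose data
// pointers remain inside the source tensor, whereas get_buffer_chunk_like() carries over the
// source's metadata (dtype, scales, scaling mode) but redirects both row-wise and column-wise
// data pointers into the registered communication buffer _ubuf. The overlap paths below use
// the former for operands in ordinary device memory and the latter for whichever operand must
// live in the userbuffer (the GEMM output in the RS paths, the all-gathered B input in the AG
// paths).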
/***************************************************************************************************
* Comm+GEMM Overlap Base (Pipelined / Collective)
**************************************************************************************************/
......@@ -138,11 +214,14 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool atomic_gemm)
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool atomic_gemm,
bool rs_overlap_first_gemm)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
num_comm_sm, set_sm_margin, false, atomic_gemm) {
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) {
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
"Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ",
......@@ -155,7 +234,8 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
}
......@@ -168,8 +248,8 @@ CommOverlapBase::~CommOverlapBase() {
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias,
void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output,
......@@ -196,7 +276,7 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
......@@ -221,20 +301,20 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, TensorWrapper &rs_output,
void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa,
const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions
size_t m = A.size(0);
size_t k = A.size(1);
size_t n = B.size(0);
size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n = _ubuf.size(0);
size_t m_chunk = m / _num_splits;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
......@@ -255,9 +335,8 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
assert(pre_gelu_out.numel() == 0);
auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr);
auto workspace_chunk =
TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
auto output_d = get_buffer_chunk_like(D, 0, {n, m});
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(),
......@@ -269,11 +348,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
&counter_ptr[i], _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
......@@ -282,11 +360,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
}
} else if (_rs_kernel_type == 2) {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m,
......@@ -299,7 +376,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(rs_output_ptr, _ubuf_scale_inv,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(rs_output_ptr, D.scale_inv(),
_ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
......@@ -321,34 +398,24 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias,
void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main) {
TensorWrapper &rs_output, cudaStream_t stream_main) {
// Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
size_t m = A.size(0);
size_t k = A.size(1);
size_t n = B.size(0);
size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n = _ubuf.size(0);
size_t m_chunk = m / _num_splits;
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t bias_chunk_size = m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
char *bias_chunk_ptr = reinterpret_cast<char *>(bias.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
......@@ -358,39 +425,23 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
assert(pre_gelu_out.numel() == 0);
if (gemm_overlap) {
auto input_a_chunk =
TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk =
TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
auto bias_chunk =
TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr);
auto workspace_chunk = TensorWrapper(
workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_rs_overlap_first_gemm) {
auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * D.element_size();
if (bias_chunk_ptr != nullptr) {
bias_chunk_ptr += bias_chunk_size * bias.element_size();
}
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
input_a_chunk = TensorWrapper(reinterpret_cast<void *>(input_a_chunk_ptr), {m_chunk, k},
A.dtype(), nullptr, nullptr, A.scale_inv());
output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr), {n, m_chunk},
D.dtype(), D.amax(), D.scale(), nullptr);
bias_chunk = TensorWrapper(reinterpret_cast<void *>(bias_chunk_ptr), {m_chunk}, bias.dtype(),
nullptr, nullptr, nullptr);
workspace_chunk = TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
......@@ -401,11 +452,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
// Communication chunk
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
......@@ -422,12 +472,11 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm););
rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk,
n, m, _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
......@@ -435,20 +484,12 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
}
} else {
for (int i = 0; i < _num_splits; i++) {
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto input_a_chunk = TensorWrapper(reinterpret_cast<void *>(input_a_chunk_ptr), {m_chunk, k},
A.dtype(), nullptr, nullptr, A.scale_inv());
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_buf_chunk_ptr),
{n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr);
auto bias_chunk = TensorWrapper(reinterpret_cast<void *>(bias_chunk_ptr), {m_chunk},
bias.dtype(), nullptr, nullptr, nullptr);
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(),
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
......@@ -461,11 +502,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m,
rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
......@@ -473,11 +513,6 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap
}
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
if (bias_chunk_ptr != nullptr) {
bias_chunk_ptr += bias_chunk_size * bias.element_size();
}
}
}
......@@ -499,11 +534,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams,
int comm_cga_size, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm, bool aggregate)
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm, bool aggregate)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
num_comm_sm, set_sm_margin, use_ce, atomic_gemm) {
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
......@@ -552,8 +589,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1));
for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream));
}
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
}
......@@ -562,7 +604,22 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send);
cudaStreamDestroy(_stream_recv);
cudaStreamDestroy(_stream_send);
for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
}
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape()));
// Update chunk with offset data pointers from the communication buffer
if (chunk.dptr() != nullptr) {
chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape());
}
if (chunk.columnwise_dptr() != nullptr) {
chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape());
}
return chunk;
}
/*
......@@ -570,12 +627,10 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed so that the AG
** outputs on each rank end up in contiguous memory after all ring-exchange phases.
*/
void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
void CommOverlapP2PBase::atomic_gemm_overlap_ag(
const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
......@@ -583,8 +638,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
// Get GEMM dimensions between TN and NN input layouts
const size_t m = (transa) ? A.size(0) : A.size(1);
const size_t n = _ubuf.size(0);
const size_t n_chunk = n / _tp_size;
const size_t n_chunk = _ubufs[0].size(0);
assert(pre_gelu_out.numel() == 0);
// Get communication and GEMM output chunk sizes
......@@ -594,7 +648,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
void *D_buffer_ptr;
int D_chunk_bytes = n_chunk * m * D.element_size();
NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main));
auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr);
auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(),
D.scale_inv(), D.scale_inv_shape(), D.scaling_mode());
// Reset atomic counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
......@@ -602,13 +657,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv());
auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape()));
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk =
TensorWrapper(workspace.dptr(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
......@@ -649,8 +703,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
cudaMemcpyDeviceToDevice, _stream_send[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
......@@ -674,11 +728,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T
** This function assumes that input_b is pre-copied to _ubufs[rank_id]. This is needed so that the AG
** outputs on each rank end up in contiguous memory after all ring-exchange phases.
*/
void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
......@@ -691,24 +746,20 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
const int output_chunk_bytes = (n_chunk * m) * D.element_size();
const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0;
// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.dptr());
char *pre_gelu_out_ptr = reinterpret_cast<char *>(pre_gelu_out.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
if (_aggregate) {
const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.dptr());
input_chunk_size *= 2;
output_chunk_size *= 2;
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
......@@ -717,11 +768,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
_stream_send);
_stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
......@@ -736,27 +787,15 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
recv_offset = comm_bytes * recv_chunk_id;
// GEMM
char *input_b_chunk_ptr = input_b_ptr + send_offset;
auto input_b_chunk =
TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(),
nullptr, nullptr, B.scale_inv());
char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes);
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr),
{n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr);
char *aux_chunk_ptr =
(do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
auto aux_chunk_shape =
(do_gelu) ? std::vector<size_t>{n_chunk * 2, m} : std::vector<size_t>{0};
auto aux_chunk = TensorWrapper(reinterpret_cast<void *>(aux_chunk_ptr), aux_chunk_shape,
pre_gelu_out.dtype());
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
: TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
......@@ -766,11 +805,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
if (i < num_steps - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, _stream_send);
next_rank, _stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
......@@ -778,7 +817,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
cudaMemcpyDeviceToDevice, _stream_send[0]));
}
}
} else {
......@@ -793,24 +832,14 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(),
nullptr, nullptr, B.scale_inv());
char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes);
auto output_chunk = TensorWrapper(reinterpret_cast<void *>(output_chunk_ptr), {n_chunk, m},
D.dtype(), D.amax(), D.scale(), nullptr);
char *aux_chunk_ptr =
(do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr;
auto aux_chunk_shape = (do_gelu) ? std::vector<size_t>{n_chunk, m} : std::vector<size_t>{0};
auto aux_chunk = TensorWrapper(reinterpret_cast<void *>(aux_chunk_ptr), aux_chunk_shape,
pre_gelu_out.dtype());
char *workspace_chunk_ptr =
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
: TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
......@@ -820,11 +849,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
if (i < _tp_size - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, _stream_send);
_next_rank, _stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
......@@ -832,7 +861,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send));
cudaMemcpyDeviceToDevice, _stream_send[0]));
}
}
}
......@@ -842,7 +871,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
......@@ -851,12 +880,10 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output,
void CommOverlapP2PBase::atomic_gemm_overlap_rs(
const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
......@@ -876,14 +903,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr);
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk =
TensorWrapper(workspace.data(), std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape()));
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(),
stream_main);
transa, transb, grad, workspace.data(), accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, true, _counter.data(), stream_main);
// P2P communication chunk
for (int i = 1; i < _tp_size; i++) {
......@@ -907,10 +930,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size,
_ubufs[0].numel(), stream_main););
} else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
......@@ -921,31 +943,33 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
size_t k = A.size(1);
size_t n = B.size(0);
// Get communication and GEMM input chunk sizes
size_t n_chunk = n / _tp_size;
size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n_chunk = _ubufs[0].size(0);
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int input_b_chunk_bytes = n_chunk * k * B.element_size();
// Get input and workspace data pointers
char *input_b_ptr = reinterpret_cast<char *>(B.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Catch up the main stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0));
for (size_t i = 0; i < _stream_send.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0));
}
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
......@@ -954,36 +978,30 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW
// GEMM and send/recv chunks
for (int i = 0; i < _tp_size; i++) {
// GEMM chunk
int stream_id = i % _stream_compute.size();
int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes);
auto input_b_chunk = TensorWrapper(reinterpret_cast<void *>(input_b_chunk_ptr), {n_chunk, k},
B.dtype(), nullptr, nullptr, B.scale_inv());
auto output_chunk =
TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr);
char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk;
auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
auto output_chunk = get_buffer_chunk_by_id(D, i);
auto workspace_chunk =
TensorWrapper(reinterpret_cast<void *>(workspace_chunk_ptr),
std::vector<size_t>{workspace_size_chunk}, workspace.dtype());
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]);
use_split_accumulator, _math_sms, _stream_compute[stream_id]);
if (i > 0) {
// P2P communication chunk
int prev_stream_id = (i - 1) % _stream_compute.size();
int send_offset = comm_bytes * (i - 1);
int recv_offset = comm_bytes * (i - 1 + _tp_size);
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
_stream_send);
_stream_send[prev_stream_id]);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
_stream_recv);
}
......@@ -993,8 +1011,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
......@@ -1002,11 +1022,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size,
_ubufs[0].numel(), stream_main););
} else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
......
......@@ -19,6 +19,7 @@
#include <stdio.h>
#include <unistd.h>
#include "common/util/system.h"
#include "userbuffers.h"
#define MAX_THREADS 1024
......
......@@ -6,27 +6,138 @@
#include <transformer_engine/transformer_engine.h>
#include <bit>
#include "./common.h"
#include "./utils.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
namespace transformer_engine {
namespace {
__global__ void __launch_bounds__(1)
update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr,
float* __restrict__ scale_inv_ptr) {
update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr,
float *__restrict__ scale_inv_ptr) {
const float scale = scale_ptr == nullptr ? 1 : *scale_ptr;
reciprocal<float>(scale_inv_ptr, scale);
}
} // namespace
void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) {
if (t->scale_inv.dptr != nullptr) {
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) {
NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float*>(t->scale.dptr), reinterpret_cast<float*>(t->scale_inv.dptr));
reinterpret_cast<const float *>(t->scale.dptr),
reinterpret_cast<float *>(t->scale_inv.dptr));
}
}
void checkCuDriverContext(CUstream stream) {
CUcontext ctx;
const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
switch (driver_status) {
case CUDA_SUCCESS:
break;
case CUDA_ERROR_INVALID_CONTEXT:
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &ctx, current_device);
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, ctx);
break;
default:
const char *desc_NVTE_CHECK_CUDA_DRIVER;
cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER);
NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER);
}
}
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
{DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
{DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
{DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
{DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
{DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
return dtypeMapping.at(dtype);
}
inline bool isPointerAligned(const void *const ptr, const int alignment) {
const uint64_t ptr_as_uint = reinterpret_cast<uint64_t>(ptr);
return ptr_as_uint % alignment == 0;
}
// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size) {
// Get a function pointer to the cuTensorMapEncodeTiled driver API
static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() {
void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
}();
// rank is the number of dimensions of the array
constexpr uint32_t rank = 2;
uint64_t size[rank] = {globalX, globalY};
// The stride is the number of bytes to traverse from the first element of one row to the next
uint64_t stride[rank - 1] = {stride_elems * type_size};
// The boxSize is the size of the shared memory buffer that is used as the
// source/destination of a TMA transfer
uint32_t boxSize[rank] = {shmemX, shmemY};
// The distance between elements in units of sizeof(element)
uint32_t elemStride[rank] = {1, 1};
const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
void *dataPtr =
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + offset_elems * type_size);
constexpr int TMA_gmem_alignment = 16; // Alignment of the global memory address
NVTE_CHECK(isPointerAligned(dataPtr, TMA_gmem_alignment),
"Tensor data pointer must be 16B aligned");
const int TMA_needed_size = TMA_gmem_alignment / type_size;
NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size,
"-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
// Create the tensor descriptor.
NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
&tensorMap, // CUtensorMap *tensorMap,
tensorDataType,
rank, // cuuint32_t tensorRank,
dataPtr, // void *globalAddress,
size, // const cuuint64_t *globalDim,
stride, // const cuuint64_t *globalStrides,
boxSize, // const cuuint32_t *boxDim,
elemStride, // const cuuint32_t *elementStrides,
// Interleave patterns can be used to accelerate loading of values that
// are less than 4 bytes long.
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
// Swizzling can be used to avoid shared memory bank conflicts.
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
// L2 Promotion can be used to widen the effect of a cache-policy to a wider
// set of L2 cache lines.
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
// CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
// Any element that is outside of bounds will be set to zero by the TMA transfer.
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
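// Illustrative sketch (not part of this change): how a caller might fill a TMA
// descriptor for a 1024x256 row-major bf16 tensor, staging 128x64 tiles in shared
// memory. The buffer pointer and all dimensions below are made up for illustration.
//
//   CUtensorMap tensor_map{};
//   SimpleTensor input(device_ptr, {1024, 256}, DType::kBFloat16);
//   create_2D_tensor_map(tensor_map, input,
//                        /*globalY=*/1024, /*globalX=*/256,
//                        /*shmemY=*/128, /*shmemX=*/64,
//                        /*stride_elems=*/256, /*offset_elems=*/0,
//                        /*type_size=*/sizeof(nv_bfloat16));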
bool is_supported_by_CC_100() {
int deviceComputeCapability = cuda::sm_arch(cuda::current_device());
return deviceComputeCapability >= 100;
}
} // namespace transformer_engine
......@@ -7,6 +7,7 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
......@@ -22,10 +23,29 @@
#include <vector>
#include "./nvtx.h"
#include "./util/cuda_driver.h"
#include "./util/logging.h"
namespace transformer_engine {
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
size_t ret = 1;
for (size_t i = begin; i < end; ++i) {
ret *= shape[i];
}
return ret;
}
inline size_t product(const std::vector<size_t> &shape) {
size_t ret = 1;
for (const auto &elem : shape) {
ret *= elem;
}
return ret;
}
struct SimpleTensor {
void *dptr;
std::vector<size_t> shape;
......@@ -33,20 +53,114 @@ struct SimpleTensor {
SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype)
: dptr(dptr), shape(shape), dtype(dtype) {}
SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT
: dptr(tensor.data_ptr),
shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
dtype(static_cast<DType>(tensor.dtype)) {}
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
operator NVTEBasicTensor() const {
const NVTEShape shape = {this->shape.data(), this->shape.size()};
return {dptr, static_cast<NVTEDType>(dtype), shape};
}
int numel() const {
size_t acc = 1;
for (const auto &dim : shape) {
acc *= dim;
}
return acc;
}
};
struct Tensor {
SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor amax;
SimpleTensor scale;
SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
NVTEScalingMode scaling_mode;
Tensor()
: data(),
columnwise_data(),
amax(nullptr, {1}, DType::kFloat32),
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32) {}
scale_inv(nullptr, {1}, DType::kFloat32),
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
int numel() const {
NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr,
"Tensor does not hold any data!");
size_t acc = 1;
if (data.dptr != nullptr) {
for (const auto &dim : data.shape) {
acc *= dim;
}
return acc;
}
// data is empty, use columnwise_data
for (const auto &dim : columnwise_data.shape) {
acc *= dim;
}
return acc;
}
bool has_data() const noexcept { return data.dptr != nullptr; }
bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; }
DType dtype() const {
if (has_data()) return data.dtype;
if (has_columnwise_data()) return columnwise_data.dtype;
// Fallback, used e.g. in workspace
return data.dtype;
}
/*! Matrix height after tensor is flattened to 2D
*
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_first_dim() const {
if (!has_data() && has_columnwise_data()) {
const auto &data_shape = columnwise_data.shape;
if (data_shape.empty()) return 1;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
return product(data_shape, 1, data_shape.size());
} else {
return product(data_shape, 0, data_shape.size() - 1);
}
}
const auto &data_shape = data.shape;
if (data_shape.empty()) return 1;
return product(data_shape, 0, data_shape.size() - 1);
}
/*! Matrix width after tensor is flattened to 2D
*
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_last_dim() const {
if (!has_data() && has_columnwise_data()) {
const auto &data_shape = columnwise_data.shape;
if (data_shape.empty()) return 1;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
return data_shape.front();
} else {
return data_shape.back();
}
}
const auto &data_shape = data.shape;
if (data_shape.empty()) return 1;
return data_shape.back();
}
};
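// Illustrative example (not part of this change): with data.shape = {2, 16, 512},
// the tensor is viewed as a (2*16) x 512 matrix, so flat_first_dim() returns 32 and
// flat_last_dim() returns 512. For delayed-scaling columnwise data, which stores the
// transposed layout, the helpers read the shape accordingly.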
template <typename T>
......@@ -62,6 +176,10 @@ using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
#endif
using e8m0_t = uint8_t;
namespace detail {
......@@ -80,6 +198,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
#undef TRANSFORMER_ENGINE_TYPE_NAME
} // namespace detail
......@@ -150,6 +271,10 @@ struct TypeInfo {
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E8M0: { \
using type = byte; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
......@@ -181,6 +306,25 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......@@ -236,15 +380,22 @@ struct TypeInfo {
NVTE_ERROR("Invalid type for 16 bit."); \
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline size_t product(const std::vector<size_t> &shape) {
size_t ret = 1;
for (const auto &elem : shape) {
ret *= elem;
#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \
switch (SCALE_DIM) { \
case 1: { \
constexpr size_t DIM = 1; \
{ __VA_ARGS__ } \
} break; \
case 32: { \
constexpr size_t DIM = 32; \
{ __VA_ARGS__ } \
} break; \
default: { \
NVTE_ERROR("Invalid size of the MX scaling factor."); \
} \
}
return ret;
}
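// Illustrative use of TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH above (not part of this
// change): dispatch on a runtime MX block size of 1 or 32, where `launch_mx_kernel`
// is a hypothetical launcher templated on the block size.
//
//   TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(scale_dim, kScaleDim,
//     launch_mx_kernel<kScaleDim>(input, output, stream););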
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int log2_ceil(int value) {
int log2_value = 0;
......@@ -269,13 +420,37 @@ struct is_fp8<fp8e4m3> : std::true_type {};
template <>
struct is_fp8<fp8e5m2> : std::true_type {};
// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
size_t typeToSize(const DType type);
void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
bool is_fp8_dtype(const DType t);
std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &type);
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_block_scaling(const NVTEScalingMode &mode) {
return mode != NVTE_DELAYED_TENSOR_SCALING;
}
inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
return is_tensor_scaling(mode);
}
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
/*! \brief Update a tensor's FP8 scale-inverse
*
* The FP8 scale-inverse (dequantization scaling factor) is updated
......@@ -286,6 +461,20 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream);
#define NVTE_API_CALL(api_name) \
transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);
void checkCuDriverContext(CUstream stream);
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
inline bool isPointerAligned(const void *const ptr, const int alignment);
// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
const uint32_t offset_elems, const size_t type_size);
bool is_supported_by_CC_100();
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
......@@ -93,17 +93,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const bool supported_ragged_offset_size =
(!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500);
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) &&
(sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) &&
(((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) &&
(max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) &&
(head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) &&
(max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) &&
((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
(qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) &&
if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) &&
sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
// 8.9: t3hd, max_s=512, d=64, padding
((cudnn_runtime_version >= 8900 && sm_arch_ < 100 &&
qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv &&
max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 &&
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
// 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
(cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
// 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal}
(cudnn_runtime_version >= 90700 &&
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// sm90: fwd d<=256, bwd d=128 only
// sm100: fwd d<=128, bwd d<=128
((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) ||
(sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
!requires_64bit_ragged_offset) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
......@@ -135,8 +149,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
!requires_64bit_ragged_offset) {
flag_m512 = true;
}
if (
// TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
if ( // architecture
// special conditions for blackwell
// TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
!(sm_arch_ == 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
// architecture
((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
(cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
// sequence length
......@@ -218,9 +236,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(cudnn_runtime_version >= 90600 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
// TODO(cyang): fix bug for BRCM + cross-attention on sm100
(sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
cudnn_runtime_version <= 90700) ||
cudnn_runtime_version > 90700)))) ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
(sm_arch_ < 100 || (sm_arch_ == 100 && ((max_seqlen_q == max_seqlen_kv &&
cudnn_runtime_version <= 90700) ||
cudnn_runtime_version > 90700))))) &&
max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
dropout == 0.0)))) &&
// check 64-bit ragged offset support
......
......@@ -227,7 +227,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_options.set_sliding_window_length(window_size_left + 1);
sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
}
sdpa_options.set_alibi_mask(is_alibi);
......@@ -457,8 +457,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
// keep original batch size because cu_seqlens are created with [b+1] shape
int64_t actual_b = b;
if (is_ragged && cudnn_runtime_version >= 90600) {
......@@ -667,7 +665,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
}
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_backward_options.set_sliding_window_length(window_size_left + 1);
sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90000) {
......
......@@ -1670,8 +1670,6 @@ void fused_attn_fp8_fwd_impl_v1(
auto bias_h = h;
NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!");
NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!");
try {
FADescriptor_v1 descriptor{b,
......@@ -1798,36 +1796,33 @@ void fused_attn_fp8_fwd_impl_v1(
// sdpa_options.set_bias(bias);
// }
// if (is_padding) {
// seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_q")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_kv")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// sdpa_options.set_padding_mask(is_padding)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }
if (is_padding) {
seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv);
}
// if (is_dropout) {
// dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Seed")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Offset")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// sdpa_options.set_dropout(
// dropout_probability, dropout_seed, dropout_offset);
// }
if (is_dropout) {
dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
}
auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8(
Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options);
......@@ -1919,29 +1914,28 @@ void fused_attn_fp8_fwd_impl_v1(
{amax_o, devPtrAmaxO},
{Stats, devPtrM}};
// if (is_bias) {
// variant_pack[bias] = devPtrBias;
// }
/* if (is_bias) {
variant_pack[bias] = devPtrBias;
} */
if (is_padding) {
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void* devActualSeqlenQ = static_cast<int8_t*>(workspace) + plan_workspace_size;
void* devActualSeqlenKV = static_cast<int8_t*>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ), // TODO(pass max_b)
static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
static_cast<int32_t*>(devActualSeqlenKV));
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
// if (is_padding) {
// constexpr size_t nthreads_per_block = 128;
// const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
// void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
// void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ)
// + b * sizeof(int32_t);
// cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
// b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
// static_cast<const int32_t *>(devPtrCuSeqlensKV),
// static_cast<int32_t *>(devActualSeqlenQ),
// static_cast<int32_t *>(devActualSeqlenKV));
// variant_pack[seq_q] = devActualSeqlenQ;
// variant_pack[seq_kv] = devActualSeqlenKV;
// }
if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
// if (is_dropout) {
// variant_pack[dropout_seed] = devPtrDropoutSeed;
// variant_pack[dropout_offset] = devPtrDropoutOffset;
// }
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException& e) {
NVTE_ERROR(e.what());
......@@ -1974,8 +1968,6 @@ void fused_attn_fp8_bwd_impl_v1(
auto bias_h = h;
NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!");
NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!");
try {
FADescriptor_v1 descriptor{b,
......@@ -2151,36 +2143,35 @@ void fused_attn_fp8_bwd_impl_v1(
// }
// }
// if (is_padding) {
// seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_q")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("seq_kv")
// .set_dim({b, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT32));
// sdpa_backward_options.set_padding_mask(is_padding)
// .set_seq_len_q(seq_q)
// .set_seq_len_kv(seq_kv);
// }
if (is_padding) {
seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
sdpa_backward_options.set_padding_mask(is_padding)
.set_seq_len_q(seq_q)
.set_seq_len_kv(seq_kv);
}
// if (is_dropout) {
// dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Seed")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
// .set_name("Offset")
// .set_dim({1, 1, 1, 1})
// .set_stride({1, 1, 1, 1})
// .set_data_type(fe::DataType_t::INT64));
// sdpa_backward_options.set_dropout(
// dropout_probability, dropout_seed, dropout_offset);
// }
if (is_dropout) {
dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT64));
sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
}
auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward(
q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s,
......@@ -2308,34 +2299,32 @@ void fused_attn_fp8_bwd_impl_v1(
{amax_dP, devPtrAmaxdP},
};
// if (is_bias) {
// variant_pack[bias] = devPtrBias;
// if ((bias_b == 1) && (bias_h == h)) {
// variant_pack[dBias] = devPtrdBias;
// } else {
// variant_pack[dBias] = nullptr;
// }
// }
// if (is_padding) {
// constexpr size_t nthreads_per_block = 128;
// const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
// void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
// void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ)
// + b * sizeof(int32_t);
// cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
// b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
// static_cast<const int32_t *>(devPtrCuSeqlensKV),
// static_cast<int32_t *>(devActualSeqlenQ),
// static_cast<int32_t *>(devActualSeqlenKV));
// variant_pack[seq_q] = devActualSeqlenQ;
// variant_pack[seq_kv] = devActualSeqlenKV;
// }
/* if (is_bias) {
variant_pack[bias] = devPtrBias;
if ((bias_b == 1) && (bias_h == h)) {
variant_pack[dBias] = devPtrdBias;
} else {
variant_pack[dBias] = nullptr;
}
} */
if (is_padding) {
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void* devActualSeqlenQ = static_cast<int8_t*>(workspace) + plan_workspace_size;
void* devActualSeqlenKV = static_cast<int8_t*>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, b, static_cast<const int32_t*>(devPtrcuSeqlensQ), // TODO(pass max_b)
static_cast<const int32_t*>(devPtrcuSeqlensKV), static_cast<int32_t*>(devActualSeqlenQ),
static_cast<int32_t*>(devActualSeqlenKV));
variant_pack[seq_q] = devActualSeqlenQ;
variant_pack[seq_kv] = devActualSeqlenKV;
}
// if (is_dropout) {
// variant_pack[dropout_seed] = devPtrDropoutSeed;
// variant_pack[dropout_offset] = devPtrDropoutOffset;
// }
if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException& e) {
......
......@@ -15,6 +15,7 @@
#include "../common.h"
#include "../util/logging.h"
#include "common/util/cuda_runtime.h"
namespace {
......@@ -46,6 +47,95 @@ uint32_t _getAlignment(uintptr_t address) {
}
}
struct GemmParam {
void *A;
void *B;
cublasOperation_t transA;
cublasOperation_t transB;
transformer_engine::DType Atype;
transformer_engine::DType Btype;
void *A_scale_inv;
void *B_scale_inv;
int lda;
int ldb;
GemmParam(cublasOperation_t transA, cublasOperation_t transB)
: A(nullptr),
B(nullptr),
transA(transA),
transB(transB),
Atype(transformer_engine::DType::kNumTypes),
Btype(transformer_engine::DType::kNumTypes),
A_scale_inv(nullptr),
B_scale_inv(nullptr),
lda(0),
ldb(0) {}
};
GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA,
const transformer_engine::Tensor &B, const cublasOperation_t transB,
const int k, const int lda, const int ldb) {
using namespace transformer_engine;
NVTE_CHECK(A.scaling_mode == B.scaling_mode,
"Inputs A and B to GEMM need to have the same scaling mode!");
NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
GemmParam ret(transA, transB);
ret.lda = lda;
ret.ldb = ldb;
if (is_tensor_scaling(A.scaling_mode)) {
ret.A = A.data.dptr;
ret.A_scale_inv = A.scale_inv.dptr;
if (transA == CUBLAS_OP_T) {
ret.Atype = A.data.dtype;
} else {
ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype;
if (is_fp8_dtype(ret.Atype)) {
int arch = cuda::sm_arch(cuda::current_device());
if (arch < 100) {
// Hopper and Ada - we need to use columnwise_data and change transA
NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!");
ret.A = A.columnwise_data.dptr;
ret.transA = CUBLAS_OP_T;
ret.A_scale_inv = A.columnwise_scale_inv.dptr;
ret.lda = k;
}
}
}
ret.B = B.data.dptr;
ret.B_scale_inv = B.scale_inv.dptr;
if (transB == CUBLAS_OP_T) {
ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype;
if (is_fp8_dtype(ret.Btype)) {
int arch = cuda::sm_arch(cuda::current_device());
if (arch < 100) {
// Hopper and Ada - we need to use columnwise_data and change transB
NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!");
ret.B = B.columnwise_data.dptr;
ret.transB = CUBLAS_OP_N;
ret.B_scale_inv = B.columnwise_scale_inv.dptr;
ret.ldb = k;
}
}
} else {
ret.Btype = B.data.dtype;
}
} else {
// If not using tensor scaling (note that tensor scaling also covers high-precision types),
// we need to pick the proper copy of the data.
// We leave the transA/B values as-is, since Blackwell supports transposes.
ret.A = transA ? A.data.dptr : A.columnwise_data.dptr;
ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype;
ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr;
ret.B = transB ? B.columnwise_data.dptr : B.data.dptr;
ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype;
ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr;
}
return ret;
}
} // namespace
namespace transformer_engine {
......@@ -56,10 +146,13 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
int math_sm_count, int m_split, int n_split, bool gemm_producer,
const Tensor *inputCounter, cudaStream_t stream) {
void *A = inputA->data.dptr;
void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
void *B_scale_inverse = inputB->scale_inv.dptr;
// Return immediately if GEMM is trivial
if (m <= 0 || n <= 0) {
return;
}
NVTE_CHECK(k > 0);
const GemmParam &param = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb);
void *C = outputD->data.dptr;
void *D = outputD->data.dptr;
void *D_scale = outputD->scale.dptr;
......@@ -72,15 +165,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
counter = inputCounter->data.dptr;
}
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype);
const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype);
const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);
NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!");
// check consistency of arguments:
......@@ -117,17 +211,17 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, transa == CUBLAS_OP_N ? m : k,
transa == CUBLAS_OP_N ? k : m, lda));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, transb == CUBLAS_OP_N ? k : n,
transb == CUBLAS_OP_N ? n : k, ldb));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k,
param.transA == CUBLAS_OP_N ? k : m, param.lda));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n,
param.transB == CUBLAS_OP_N ? n : k, param.ldb));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa)));
&param.transA, sizeof(param.transA)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
&transb, sizeof(transb)));
&param.transB, sizeof(param.transB)));
// Set math SM count
if (math_sm_count != 0) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
......@@ -143,12 +237,53 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
&fastAccuMode, sizeof(fastAccuMode)));
// Scaling factors.
#if CUDA_VERSION >= 12080
cublasLtMatmulMatrixScale_t scaling_mode;
#endif
if ((is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode))) {
void *A_scale_inverse = param.A_scale_inv;
void *B_scale_inverse = param.B_scale_inv;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
#if CUDA_VERSION >= 12080
scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
} else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) {
fp8e8m0 *A_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.A_scale_inv);
fp8e8m0 *B_scale_inverse = reinterpret_cast<fp8e8m0 *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublasLtGetVersion() <= 120803) {
const int64_t dummy_a_vec_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
sizeof(dummy_a_vec_stride)));
}
#endif
} else {
NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " +
to_string(inputB->scaling_mode) + ".");
}
#if CUDA_VERSION >= 12080
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
#endif
if (is_fp8_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8 output
C = nullptr;
......@@ -156,8 +291,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
// For FP8 output, cuBLAS requires C_type to be same as bias_type
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd));
#if CUDA_VERSION >= 12080
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode)));
#endif
// For FP8 output, cuBLAS requires C_type to match bias_type and
// be FP16/BF16
const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF;
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd));
} else {
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
}
......@@ -235,8 +376,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(A));
const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(B));
const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.A));
const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
......@@ -260,8 +401,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc,
static_cast<const void *>(&one), /* alpha */
A, /* A */
Adesc, B, /* B */
param.A, /* A */
Adesc, param.B, /* B */
Bdesc, static_cast<const void *>(&beta), /* beta */
C, /* C */
Cdesc, D, /* D */
......@@ -270,7 +411,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
workspaceSize, stream)); /* stream */
// Update FP8 scale-inv in output tensor
if (is_fp8_dtype(outputD->data.dtype)) {
// Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated.
// TODO: Change the GEMM interface so that D->scale_inv is allocated and the scale_inv can be
// calculated here.
if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) {
update_tensor_scale_inv(outputD, stream);
}
......@@ -309,9 +453,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
const size_t A0 = inputA->flat_first_dim();
const size_t A1 = inputA->flat_last_dim();
const size_t B0 = inputB->flat_first_dim();
const size_t B1 = inputB->flat_last_dim();
const int m = transa ? A0 : A1;
const int k = transa ? A1 : A0;
const int n = transb ? B1 : B0;
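// Illustrative example (not part of this change): for a typical TN GEMM with A
// flattened to [4096, 1024] and B flattened to [2048, 1024], transa = true and
// transb = false give m = 4096, k = 1024, n = 2048.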
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
......@@ -357,6 +506,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode),
"Atomic GEMM only supports delayed scaling.");
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
......
......@@ -19,7 +19,9 @@ extern "C" {
/* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */
/*! \brief Compute activation of the input.
/*! \brief Computes activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
......@@ -39,17 +41,59 @@ enum class NVTE_Activation_Type {
SREGLU,
};
/*! \brief Computes the GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Compute activation gradient.
/*! \brief Computes the GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
......@@ -59,19 +103,57 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Compute gated activation of the input.
/*! \brief Computes the gated GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
......@@ -80,15 +162,54 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
*/
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated Swish activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
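/* Illustrative shapes for the gated activations (not part of this change): with an
 * input of shape [N, H * 2] = [512, 8192], the first 4096 columns go through the
 * activation and are multiplied elementwise by the remaining 4096 columns, producing
 * an output of shape [512, 4096].
 */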
/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated Quick GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated Squared ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Compute gated activation gradient.
/*! \brief Computes the gated GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
......@@ -97,15 +218,51 @@ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the gated Swish activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the gated Quick GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the gated Squared ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
......
......@@ -5,7 +5,7 @@
************************************************************************/
/*! \file cast.h
* \brief Functions to cast to/from FP8.
* \brief Functions to cast to/from FP8/MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_CAST_H_
......@@ -17,21 +17,200 @@
extern "C" {
#endif
/*! \brief Cast tensor to FP8.
/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer)
* The implementation is per the microscaling format MXFP8 defined by the OCP specification:
* https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
*
* Supported modes of scaling (live scaling):
* 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes:
* - the scaled output tensor
* - the corresponding scaling factors
* The scaling factors are computed for blocks of the shape [1,32]
* (i.e., each scaling factor spans 32 contiguous elements along rows).
*
* 2) Columnwise scaling (along the dim=1) computes one set of the output data.
* The scaling factors are computed for blocks of the shape [32,1]
* (i.e., each scaling factor spans 32 contiguous elements along columns).
*
* 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1)
* computes two sets of the output data: both 1) and 2).
*
* The shape of the MX block must be specified in the 'output' argument,
* and can be either [1,32] or [32,1] as no other shapes are currently supported.
*
* To cast the input tensor to MXFP8, the scaling mode of the output tensor
* should be set to NVTE_MXFP8_1D_SCALING.
*/
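/* Illustrative example (not part of this change, ignoring any alignment padding of
 * the scaling-factor tensors): casting a 128 x 256 input with rowwise [1,32] blocks
 * produces 256 / 32 = 8 scaling factors per row, i.e. 128 x 8 factors in total;
 * with columnwise [32,1] blocks it produces (128 / 32) x 256 = 4 x 256 factors.
 * Requesting both modes produces both sets of outputs.
 */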
/*! \brief Casts input tensor to FP8/MXFP8.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] output Output FP8 tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
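Putting the scaling-mode description above into practice: a single nvte_quantize call covers both the per-tensor FP8 path and the MXFP8 path, and the difference is carried entirely by how the output tensor was configured (its scaling mode and, for MXFP8, which of the rowwise/columnwise outputs it allocates). A minimal sketch follows; make_hi_prec_tensor() and make_fp8_output() are hypothetical helpers for wrapping pre-allocated device buffers, not part of the API shown here.

#include <cuda_runtime.h>
#include <transformer_engine/cast.h>

// Hypothetical helpers: wrap pre-allocated device buffers into NVTETensor
// handles with the requested dtype / scaling mode.
extern NVTETensor make_hi_prec_tensor(void *dptr, size_t rows, size_t cols);
extern NVTETensor make_fp8_output(size_t rows, size_t cols, bool mxfp8_1d_scaling);

void quantize_example(void *src, size_t rows, size_t cols, cudaStream_t stream) {
  NVTETensor input = make_hi_prec_tensor(src, rows, cols);

  // Per-tensor FP8 scaling: one scale / scale_inv for the whole tensor.
  NVTETensor fp8_out = make_fp8_output(rows, cols, /*mxfp8_1d_scaling=*/false);
  nvte_quantize(input, fp8_out, stream);

  // MXFP8: the output is configured with NVTE_MXFP8_1D_SCALING, so the kernel
  // emits one scaling factor per [1,32] block (rowwise) and/or per [32,1]
  // block (columnwise), depending on which outputs the tensor allocates.
  NVTETensor mxfp8_out = make_fp8_output(rows, cols, /*mxfp8_1d_scaling=*/true);
  nvte_quantize(input, mxfp8_out, stream);
}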
/*! \brief Cast tensor from FP8.
/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] noop Noop tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream);
/*! \brief Casts input tensor to FP8/MXFP8. Additionally, reduces the input along columns.
 * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 * This function produces 2 results:
 * - `output` is equal to `cast(input)`
 * - `dbias` is equal to `reduce(input, dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[out] output Output tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream);
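The workspace-query convention described above (call once with an empty workspace to learn its required shape and dtype, allocate it, then call again) applies to nvte_quantize_dbias and to all of the fused dact variants that follow. A minimal sketch of that two-call pattern; make_empty_tensor() and allocate_like() are hypothetical helpers for tensor construction and allocation.

// Hypothetical helpers (not part of the API shown here).
extern NVTETensor make_empty_tensor();                          // shape-query placeholder
extern NVTETensor allocate_like(const NVTETensor shape_query);  // allocates per queried shape/dtype

void quantize_dbias_example(NVTETensor input, NVTETensor output, NVTETensor dbias,
                            cudaStream_t stream) {
  // First call: the workspace is empty, so no work is performed; the call only
  // fills in the required workspace shape and dtype.
  NVTETensor workspace = make_empty_tensor();
  nvte_quantize_dbias(input, output, dbias, workspace, stream);

  // Allocate the workspace the kernel asked for, then run the real operation:
  // output = cast(input), dbias = column reduction of the input.
  workspace = allocate_like(workspace);
  nvte_quantize_dbias(input, output, dbias, workspace, stream);
}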
/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
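The fused dact variants add one extra argument, act_input, which (as the parameter names suggest) carries the tensor the activation saw in the forward pass, while input carries the incoming gradient pushed through the activation backward. A sketch for the GeLU case, reusing the two-call workspace convention and the hypothetical make_empty_tensor()/allocate_like() helpers from the previous sketch:

void quantize_dbias_dgelu_example(NVTETensor grad, NVTETensor gelu_fwd_input,
                                  NVTETensor output, NVTETensor dbias,
                                  cudaStream_t stream) {
  // Query the workspace, allocate it, then perform the fused
  // dGeLU -> cast -> column reduction in a single launch.
  NVTETensor workspace = make_empty_tensor();  // hypothetical helper
  nvte_quantize_dbias_dgelu(grad, gelu_fwd_input, output, dbias, workspace, stream);
  workspace = allocate_like(workspace);        // hypothetical helper
  nvte_quantize_dbias_dgelu(grad, gelu_fwd_input, output, dbias, workspace, stream);
}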
/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor to be cast.
* \param[in] act_input Activation input tensor.
* \param[in,out] output Output FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Casts input tensor from reduced to higher precision.
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,
* the block dequantization (MXFP8) of the specified shape of the block will be used.
 * In the case of MXFP8 dequantization, the dequantized values are stored in the rowwise
 * data of the output tensor, regardless of whether rowwise or columnwise scaling is used.
*
* \param[in] input Input FP8/MXFP8 tensor to be cast.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
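nvte_dequantize goes the other way: an FP8/MXFP8 tensor plus its scaling metadata is cast back to a higher-precision dtype, with the caveat above that MXFP8 dequantization always lands in the rowwise data of the output. A minimal round-trip sketch, reusing the hypothetical make_hi_prec_tensor()/make_fp8_output() helpers assumed earlier:

void roundtrip_example(void *src, void *dst, size_t rows, size_t cols,
                       cudaStream_t stream) {
  NVTETensor input  = make_hi_prec_tensor(src, rows, cols);   // hypothetical helper
  NVTETensor fp8    = make_fp8_output(rows, cols, /*mxfp8_1d_scaling=*/true);
  NVTETensor output = make_hi_prec_tensor(dst, rows, cols);

  nvte_quantize(input, fp8, stream);     // higher precision -> MXFP8 blocks + scaling factors
  nvte_dequantize(fp8, output, stream);  // MXFP8 -> higher precision (rowwise output data)
  cudaStreamSynchronize(stream);
}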
#ifdef __cplusplus
} // extern "C"
......
......@@ -17,11 +17,26 @@
extern "C" {
#endif
/*! \brief Transposes the input, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
*
* \param[in] input Input tensor.
* \param[in] noop Noop tensor.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream);
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop,
NVTETensor cast_output, NVTETensor transposed_output,
/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
*
* \param[in] input Input tensor.
* \param[in] noop Noop tensor.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream);
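Both *_with_noop entry points follow the same pattern as nvte_quantize_noop: the extra tensor is read on the device and, when it signals "skip", the kernel exits immediately, which is mainly useful when the launch is baked into a CUDA graph. A sketch of the cast-transpose variant, reusing the hypothetical make_flag_tensor() helper from the quantization example:

void cast_transpose_if_enabled(NVTETensor input, NVTETensor output,
                               float *device_flag, bool skip, cudaStream_t stream) {
  float host_value = skip ? 1.0f : 0.0f;
  cudaMemcpy(device_flag, &host_value, sizeof(float), cudaMemcpyHostToDevice);
  NVTETensor noop = make_flag_tensor(device_flag);  // hypothetical helper

  // Fused cast + transpose; with the TE 2.0 tensor layout the single 'output'
  // argument is expected to carry both the cast (rowwise) and the transposed
  // (columnwise) results.
  nvte_cast_transpose_with_noop(input, noop, output, stream);
}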
#ifdef __cplusplus
......
......@@ -53,6 +53,8 @@ class CommOverlapCore {
int _cga_size;
int _use_ce;
int _ub_reg;
int _gemm_priority;
int _comm_priority;
bool _atomic_gemm{false};
bool _is_p2p{false};
......@@ -65,10 +67,13 @@ class CommOverlapCore {
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
public:
CommOverlapCore() {} // dummy constructor for exposing type to Python
CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm,
bool set_sm_margin, bool use_ce, bool atomic_gemm);
int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority,
int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm);
virtual ~CommOverlapCore();
......@@ -77,25 +82,76 @@ class CommOverlapCore {
_ubuf_scale_inv_initialized = true;
}
TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; }
bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }
virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, CommOverlapType comm_type,
TensorWrapper &rs_output, cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &B_copy, cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
}; // CommOverlapCore
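CommOverlapCore deliberately gives every overlap algorithm a throwing default, so a backend only overrides the variants it actually implements and callers reach all of them through the same virtual interface. An illustrative sketch of that extension pattern follows; the derived class, its body, and the include path are assumptions, not part of the library.

#include <transformer_engine/comm_gemm_overlap.h>  // assumed install path

using transformer_engine::TensorWrapper;

// Illustrative only: a custom backend overrides just the algorithms it supports;
// the inherited defaults keep raising NVTE_ERROR("Operation is not implemented.").
class MyRingOverlap : public transformer_engine::CommOverlapCore {
 public:
  using transformer_engine::CommOverlapCore::CommOverlapCore;

  void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
                        bool transb, TensorWrapper &D, TensorWrapper &bias,
                        TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
                        bool accumulate, bool use_split_accumulator,
                        TensorWrapper &rs_output, cudaStream_t stream_main) override {
    // ... overlap chunked GEMMs on stream_main with reduce-scatter on a
    //     dedicated communication stream ...
  }
};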
class CommOverlapBase : public CommOverlapCore {
protected:
int _rs_kernel_type;
bool _rs_overlap_first_gemm;
cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy;
public:
CommOverlapBase() {} // dummy constructor for exposing type to Python
CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);
int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
bool set_sm_margin = true, bool atomic_gemm = false,
bool rs_overlap_first_gemm = false);
virtual ~CommOverlapBase();
......@@ -103,97 +159,124 @@ class CommOverlapBase : public CommOverlapCore {
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main);
void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output,
cudaStream_t stream_main) override;
void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
}
void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
}
/*
** Split FPROP GEMM + ReduceScatter
*/
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap,
TensorWrapper &rs_output, cudaStream_t stream_main);
void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override;
/*
** Split FPROP GEMM + ReduceScatter
*/
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output,
cudaStream_t stream_main);
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override;
}; // CommOverlapBase
class CommOverlapP2PBase : public CommOverlapCore {
protected:
bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false};
bool _aggregate;
int _next_rank;
int _prev_rank;
int _rank_round_tp;
int _aggregate;
int _num_ubuf_chunks;
int _self_chunk_id;
std::vector<TensorWrapper> _ubufs;
cudaStream_t _stream_send;
std::vector<cudaStream_t> _stream_send;
cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv;
public:
CommOverlapP2PBase() {} // dummy constructor for exposing type to Python
CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS,
int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false,
bool use_ce = true, bool atomic_gemm = false, bool aggregate = false);
int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0,
int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true,
bool atomic_gemm = false, bool aggregate = false);
virtual ~CommOverlapP2PBase();
TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id);
void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
}
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed so that the
** AG outputs on each rank end up in contiguous memory after all ring exchange phases.
*/
void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main);
void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) override;
/*
** Split AllGather + GEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed so that the
** AG outputs on each rank end up in contiguous memory after all ring exchange phases.
*/
void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main);
cudaStream_t stream_main) override;
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main);
void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override;
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main);
cudaStream_t stream_main) override;
}; // CommOverlapP2PBase
} // namespace transformer_engine
......
......@@ -28,16 +28,10 @@ extern "C" {
* \param[in] amax_history History of maximum absolute values.
* Shape: [history_length, num_scales]
* \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales]
* \param[in] scale_inv Scaling factor for casting from FP8. Shape: [num_scales]
* \param[in] scale_inv_mask Boolean mask indicating scale_inv entries to update. May be
* empty, in which case all scale_inv entries are updated.
* Shape: [num_scales]
* \param[out] updated_amax_history Updated history of maximum absolute values.
* Shape: [history_length, num_scales]
* \param[out] updated_scale Updated scaling factor for casting to FP8.
* Shape: [num_scales]
* \param[out] updated_scale_inv Updated scaling factor for casting from FP8.
* Shape: [num_scales]
* \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and
* "most_recent".
* \param[in] fp8_dtype FP8 datatype.
......@@ -45,9 +39,8 @@ extern "C" {
* \param[in] stream CUDA stream.
*/
void nvte_delayed_scaling_recipe_amax_and_scale_update(
const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv,
const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale,
NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history,
NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
cudaStream_t stream);
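With the scale_inv outputs removed, the per-step delayed-scaling update only needs the amax history and the scales. A sketch of a typical call, reducing the history with "max" and applying a zero margin for E4M3; the tensors are assumed to be already constructed with the shapes documented above.

void amax_scale_update_example(NVTETensor amax_history, NVTETensor scale,
                               NVTETensor updated_amax_history, NVTETensor updated_scale,
                               cudaStream_t stream) {
  // Reduce each column of the [history_length, num_scales] amax history with
  // "max", then recompute the FP8 scales for E4M3 with no extra margin.
  nvte_delayed_scaling_recipe_amax_and_scale_update(
      amax_history, scale, updated_amax_history, updated_scale,
      /*amax_compute_algo=*/"max", /*fp8_dtype=*/kNVTEFloat8E4M3,
      /*margin=*/0.0f, stream);
}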
/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction.
......@@ -55,7 +48,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
 * Operations performed include updating the most recent amax history
 * with the relevant segment of the global reduction buffer if it is not 0,
* rotating the amax history based on the rule below, and updating the
* scales and scale_invs.
* scales.
*
* The amax history is rotated by -1 (e.g. the first entry shifts to
* the last, the last entry shifts to the second to last) and the
......@@ -69,8 +62,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
* Shape: num_tensors x [history_length, num_scales]
* \param[in,out] scales List of scaling factors for casting to FP8.
* Shape: num_tensors x [num_scales]
* \param[in,out] scale_invs List of scaling factors for casting from FP8.
* Shape: num_tensors x [num_scales]
* \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and
* "most_recent".
* \param[in] fp8_dtype FP8 datatype.
......@@ -79,8 +70,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
*/
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs,
const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream);
std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
float margin, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file swizzle.h
 * \brief Functions to swizzle the MXFP8 scaling factors into the layout required by GEMM.
*/
#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_
#define TRANSFORMER_ENGINE_SWIZZLE_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Swizzles the scaling factors into the interleaved layout required by GEMM.
*
* \param[in] input Input tensor with non-swizzled scale_inv.
* \param[in,out] output Output tensor which hosts swizzled scale_inv.
* \param[in] stream CUDA stream used for the operation.
*
* Requirements:
 * - scale_inv is stored in row-major order.
 * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
 * - data is quantized along the K-dimension, i.e. the 1D scaling block lies along the K-dimension.
*/
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);
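For MXFP8 GEMMs, the scaling factors produced by quantization are stored row-major and must be swizzled into the interleaved layout the GEMM expects before the matrix multiply; only scale_inv is rearranged, the quantized data itself is untouched. A minimal sketch, assuming the output tensor is set up like the input but provides its own scale_inv buffer, and assuming the install path of this header:

#include <transformer_engine/swizzle.h>  // assumed install path

void swizzle_before_gemm(NVTETensor quantized, NVTETensor swizzled, cudaStream_t stream) {
  // 'quantized' holds MXFP8 data plus row-major scale_inv (padded to 128x4 / 4x128);
  // 'swizzled' references the same data but hosts the interleaved scale_inv
  // layout consumed by the GEMM.
  nvte_swizzle_scaling_factors(quantized, swizzled, stream);
  // ... pass 'swizzled' to the GEMM call ...
}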
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_SWIZZLE_H_