Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
......@@ -30,6 +30,7 @@ enum NVTEDType {
kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */
kNVTENumTypes /*!< Number of supported types */
};
......@@ -43,6 +44,42 @@ struct NVTEShape {
size_t ndim;
};
/*! \struct NVTEBasicTensor
* \brief A basic tensor type used to populate parameters of NVTETensor.
* It does not own the memory it points to.
*/
struct NVTEBasicTensor {
void *data_ptr;
NVTEDType dtype;
NVTEShape shape;
};
/*! \enum NVTETensorParam
* \brief Indicates the kind of the tensor parameter to set/get.
*/
enum NVTETensorParam {
kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */
kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */
kNVTEScale = 2, /*!< Scale tensor */
kNVTEAmax = 3, /*!< Amax tensor */
kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTENumTensorParams
};
/*! \enum NVTEScalingMode
* \brief Granularity of scaling:
*/
enum NVTEScalingMode {
/*! Single scale per tensor, computed in a delayed manner.
Also used for high-precision data without scaling */
NVTE_DELAYED_TENSOR_SCALING = 0,
/*! Single scale per block of 32 consecutive elements in either
the rowwise or columnwise direction */
NVTE_MXFP8_1D_SCALING = 1,
NVTE_INVALID_SCALING
};
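To make the NVTE_MXFP8_1D_SCALING granularity concrete, the following host-side sketch (illustration only, not part of this header; the exact rounding rule is not shown in this diff) computes one power-of-two scale per block of 32 consecutive elements, which is the quantity an E8M0 scale-inverse tensor stores:

// Illustration only: one decoding scale per 32 consecutive elements, as in
// NVTE_MXFP8_1D_SCALING. The power-of-two rounding below is a plausible choice,
// not necessarily the exact rule used by the library.
#include <algorithm>
#include <cmath>
#include <vector>

constexpr size_t kBlockSize = 32;    // elements sharing one scale
constexpr float kE4M3Max = 448.0f;   // largest magnitude representable in FP8 E4M3

std::vector<float> per_block_scales(const std::vector<float>& x) {
  std::vector<float> scales((x.size() + kBlockSize - 1) / kBlockSize, 1.0f);
  for (size_t b = 0; b < scales.size(); ++b) {
    float amax = 0.0f;
    const size_t end = std::min(x.size(), (b + 1) * kBlockSize);
    for (size_t i = b * kBlockSize; i < end; ++i) amax = std::max(amax, std::fabs(x[i]));
    if (amax > 0.0f) {
      // Round down to a power of two so the scale fits an E8M0 exponent.
      scales[b] = std::exp2(std::floor(std::log2(amax / kE4M3Max)));
    }
  }
  return scales;  // element i is encoded as x[i] / scales[i / kBlockSize]
}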
/*! \brief TE Tensor type
*
* NVTETensor is a contiguous tensor type storing a pointer
......@@ -53,21 +90,15 @@ typedef void *NVTETensor;
/*! \brief Create a new TE tensor.
*
* Create a new TE tensor with a given shape, datatype and data.
* Create a new TE tensor. Its parameters need to be set before use.
* TE tensors are just wrappers on top of raw data and do not
* own memory.
*
* \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scaling_mode Scaling mode of the tensor.
*
* \return A new TE tensor.
*/
NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype,
float *amax_dptr, float *scale_dptr, float *scale_inv_dptr);
NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode);
/*! \brief Destroy a TE tensor.
*
......@@ -78,14 +109,22 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType
*/
void nvte_destroy_tensor(NVTETensor tensor);
/*! \brief Get a raw pointer to the tensor's data.
/*! \brief Get a raw pointer to the tensor's rowwise data.
*
* \param[in] tensor Tensor.
*
* \return A raw pointer to tensor's data.
* \return A raw pointer to tensor's rowwise data.
*/
void *nvte_tensor_data(const NVTETensor tensor);
/*! \brief Get a raw pointer to the tensor's columnwise data.
*
* \param[in] tensor Tensor.
*
* \return A raw pointer to tensor's columnwise data.
*/
void *nvte_tensor_columnwise_data(const NVTETensor tensor);
/*! \brief Get a tensor's data shape.
*
* \param[in] tensor Tensor.
......@@ -94,6 +133,14 @@ void *nvte_tensor_data(const NVTETensor tensor);
*/
NVTEShape nvte_tensor_shape(const NVTETensor tensor);
/*! \brief Get a tensor's columnwise data shape.
*
* \param[in] tensor Tensor.
*
* \return A shape of the input tensor's columnwise data.
*/
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor);
/*! \brief Get a tensor's number of dimensions.
*
* \param[in] tensor Tensor.
......@@ -159,6 +206,46 @@ float *nvte_tensor_scale(const NVTETensor tensor);
*/
float *nvte_tensor_scale_inv(const NVTETensor tensor);
/*! \brief Get a tensor's scale_inv shape.
*
* \param[in] tensor Tensor.
*
* \return A scale_inv shape of the input tensor.
*/
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor);
/*! \brief Reset tensor value to zero.
*
* \param[in] tensor Tensor.
*
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream);
/*! \brief Set a parameter of the tensor.
*
* \param[in,out] tensor Tensor.
* \param[in] param_name The parameter to be set.
* \param[in] param The value to be set.
*/
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
const NVTEBasicTensor *param);
/*! \brief Get a value of the parameter of the tensor.
*
* \param[in] tensor Tensor.
* \param[in] param_name The parameter to be retrieved.
*
* \return The value of the requested parameter.
*/
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name);
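A rough usage sketch of the parameter-based API above (not taken from the sources; the device buffer and sizes are placeholders supplied by the caller):

// Sketch only: create a tensor, attach rowwise FP8 data, read the parameter back.
// `data_dev` is assumed to be a device buffer holding rows*cols FP8 E4M3 values;
// the tensor does not take ownership of it.
#include <transformer_engine/transformer_engine.h>

void example_rowwise_tensor(void *data_dev, size_t rows, size_t cols) {
  const size_t dims[2] = {rows, cols};
  const NVTEShape shape = {dims, 2};

  NVTETensor t = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);
  NVTEBasicTensor rowwise = {data_dev, kNVTEFloat8E4M3, shape};
  nvte_set_tensor_param(&t, kNVTERowwiseData, &rowwise);

  NVTEBasicTensor readback = nvte_get_tensor_param(t, kNVTERowwiseData);
  (void)readback;  // same data_ptr, dtype and shape as set above

  nvte_destroy_tensor(t);  // frees the wrapper only, not data_dev
}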
/*! \brief Get the granularity of scaling of this tensor.
*
* \param[in] tensor Tensor.
*
* \return The scaling mode describing the granularity of the tensor's scaling.
*/
NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor);
/*! \struct NVTETensorPack
\brief Pack of tensors, generally used for auxiliary outputs.
*/
......@@ -201,6 +288,7 @@ enum class DType {
kBFloat16 = 5,
kFloat8E4M3 = 6,
kFloat8E5M2 = 7,
kFloat8E8M0 = 8,
kNumTypes
};
......@@ -220,12 +308,23 @@ class TensorWrapper {
* \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scale_inv_shape Shape of the scale_inv tensor.
* \param[in] scaling_mode Scaling mode of the tensor.
*/
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr)
: tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype), amax_dptr,
scale_dptr, scale_inv_dptr)) {}
float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr,
const NVTEShape scale_inv_shape = defaultShape,
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) {
tensor_ = nvte_create_tensor(scaling_mode);
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(dtype), shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data);
NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape};
nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax);
NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape};
nvte_set_tensor_param(&tensor_, kNVTEScale, &scale);
NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv);
}
/*! \brief Constructs new TensorWrapper.
*
......@@ -238,19 +337,23 @@ class TensorWrapper {
* \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scale_inv_shape Shape of the scale_inv tensor.
* \param[in] scaling_mode Scaling mode of the tensor.
*/
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype,
float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr)
float *scale_inv_dptr = nullptr, const std::vector<size_t> &scale_inv_shape = {1},
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
: TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr,
scale_inv_dptr) {}
scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()},
scaling_mode) {}
/*! \brief Constructs new empty TensorWrapper.
*
* Create a new empty TE tensor which holds nothing.
*/
TensorWrapper() : TensorWrapper(nullptr, std::vector<size_t>(), DType::kFloat32) {}
explicit TensorWrapper(const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
: tensor_(nvte_create_tensor(scaling_mode)) {}
/*! \brief TensorWrapper destructor. */
~TensorWrapper() { nvte_destroy_tensor(tensor_); }
......@@ -283,6 +386,70 @@ class TensorWrapper {
return *this;
}
// Parameter setters
template <typename ShapeType>
TensorWrapper &set_parameter(const NVTETensorParam param, void *dptr, DType type,
const ShapeType &shape) noexcept {
NVTEShape nvte_shape = this->convertShape(shape);
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(type), nvte_shape};
nvte_set_tensor_param(&tensor_, param, &data);
return *this;
}
template <typename ShapeType>
TensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTERowwiseData, dptr, type, shape);
}
template <typename ShapeType>
TensorWrapper &set_columnwise_data(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTEColumnwiseData, dptr, type, shape);
}
template <typename ShapeType>
TensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTEScale, dptr, type, shape);
}
template <typename ShapeType>
TensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTEAmax, dptr, type, shape);
}
template <typename ShapeType>
TensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTERowwiseScaleInv, dptr, type, shape);
}
template <typename ShapeType>
TensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape);
}
// Parameter getters
NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept {
return nvte_get_tensor_param(tensor_, param);
}
NVTEBasicTensor get_rowwise_data() const noexcept { return get_parameter(kNVTERowwiseData); }
NVTEBasicTensor get_columnwise_data() const noexcept {
return get_parameter(kNVTEColumnwiseData);
}
NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEScale); }
NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEAmax); }
NVTEBasicTensor get_rowwise_scale_inv() const noexcept {
return get_parameter(kNVTERowwiseScaleInv);
}
NVTEBasicTensor get_columnwise_scale_inv() const noexcept {
return get_parameter(kNVTEColumnwiseScaleInv);
}
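A rough sketch of populating an empty wrapper through these setters (illustration only; the device buffers are placeholders and the exact scale-inverse layout or padding used by the library may differ):

// Sketch only: wrap pre-allocated device buffers as an MXFP8 tensor.
#include <transformer_engine/transformer_engine.h>
#include <vector>

void example_mxfp8_wrapper(void *data_dev, void *scale_inv_dev, size_t rows, size_t cols) {
  using namespace transformer_engine;
  TensorWrapper t(NVTE_MXFP8_1D_SCALING);
  t.set_rowwise_data(data_dev, DType::kFloat8E4M3, std::vector<size_t>{rows, cols})
      .set_rowwise_scale_inv(scale_inv_dev, DType::kFloat8E8M0,
                             // one scale per 32 elements; assumes cols is a multiple of 32
                             std::vector<size_t>{rows, cols / 32});
  NVTEBasicTensor rowwise = t.get_rowwise_data();  // {data_ptr, dtype, shape}
  (void)rowwise;
}  // the wrapper destructor releases the NVTETensor handle, not the device buffers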
/*! \brief Get an underlying NVTETensor.
*
* \return NVTETensor held by this TensorWrapper.
......@@ -298,6 +465,15 @@ class TensorWrapper {
return nvte_tensor_shape(tensor_);
}
/*! \brief Get the columnwise data shape of this TensorWrapper.
*
* \return Columnwise data shape of this TensorWrapper.
*/
const NVTEShape columnwise_shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
return nvte_tensor_columnwise_shape(tensor_);
}
/*! \brief Get the size of this TensorWrapper in the given dimension.
*
* \param[in] size_t Dimension index.
......@@ -366,6 +542,15 @@ class TensorWrapper {
return nvte_tensor_data(tensor_);
}
/*! \brief Get a raw pointer to the tensor's columnwise data.
*
* \return A raw pointer to tensor's columnwise data.
*/
void *columnwise_dptr() const noexcept {
if (tensor_ == nullptr) return nullptr;
return nvte_tensor_columnwise_data(tensor_);
}
/*! \brief Get a pointer to the tensor's amax data.
*
* \return A pointer to tensor's amax data.
......@@ -393,7 +578,34 @@ class TensorWrapper {
return nvte_tensor_scale_inv(tensor_);
}
/*! \brief Get the scale_inv_shape of this TensorWrapper.
*
* \return scale_inv_shape of this TensorWrapper.
*/
const NVTEShape scale_inv_shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
return nvte_tensor_scale_inv_shape(tensor_);
}
/*! \brief Get a scaling mode of the tensor.
*
* \return Scaling mode of the tensor.
*/
NVTEScalingMode scaling_mode() const noexcept {
if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING;
return nvte_tensor_scaling_mode(tensor_);
}
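/*! \brief Reset the tensor data to zero (see nvte_zero_tensor).
*
* \param[in] stream CUDA stream used for the operation.
*/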
void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); }
static constexpr size_t defaultData = 1;
static constexpr NVTEShape defaultShape = {&defaultData, 1};
private:
NVTEShape convertShape(const NVTEShape &s) { return s; }
NVTEShape convertShape(const std::vector<size_t> &s) { return {s.data(), s.size()}; }
/*! \brief Wrapped NVTETensor. */
NVTETensor tensor_ = nullptr;
};
......
......@@ -20,16 +20,16 @@ extern "C" {
/*! \brief Cast and transpose the input.
*
* This function casts the input and produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
* - rowwise data in `output` is the result of the cast
* - columnwise data in `output` is the transposed result of the cast.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in,out] output Result of the cast and transpose.
* Shape of the rowwise data: [N, H].
* Shape of the columnwise data: [H, N]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
NVTETensor transposed_output, cudaStream_t stream);
void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream);
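A rough usage sketch for the single-output form above (illustration only; buffers, sizes and the FP8 scaling metadata are placeholders prepared by the caller, and the header names in the includes are assumptions):

// Sketch only: one output tensor carries both the cast ([N, H], rowwise) and the
// transposed cast ([H, N], columnwise) results of nvte_cast_transpose.
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <vector>

void example_cast_transpose(void *input_dev, void *cast_dev, void *transpose_dev, float *amax_dev,
                            float *scale_dev, float *scale_inv_dev, size_t N, size_t H,
                            cudaStream_t stream) {
  using namespace transformer_engine;
  TensorWrapper input(input_dev, std::vector<size_t>{N, H}, DType::kBFloat16);
  TensorWrapper output(NVTE_DELAYED_TENSOR_SCALING);
  output.set_rowwise_data(cast_dev, DType::kFloat8E4M3, std::vector<size_t>{N, H})
      .set_columnwise_data(transpose_dev, DType::kFloat8E4M3, std::vector<size_t>{H, N})
      .set_amax(amax_dev, DType::kFloat32, std::vector<size_t>{1})
      .set_scale(scale_dev, DType::kFloat32, std::vector<size_t>{1})
      .set_rowwise_scale_inv(scale_inv_dev, DType::kFloat32, std::vector<size_t>{1});
  nvte_cast_transpose(input.data(), output.data(), stream);
}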
/*! \brief Transpose the input.
*
......@@ -41,25 +41,24 @@ void nvte_transpose(const NVTETensor input, NVTETensor transposed_output, cudaSt
/*! \brief Cast and transpose the input. Additionally, reduce the input along the first dimension.
*
* This function casts the input and produces 3 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
* This function casts the input and produces 2 results:
* - `output` is the result of the cast (rowwise data) and transposed cast (columnwise data)
* - `dbias` is the result of the reduction of the input along the first dimension.
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in,out] output Result of the cast and transpose.
* Shape of the rowwise data: [N, H].
* Shape of the columnwise data: [H, N]
* \param[out] dbias Result of the reduction of the input along the
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream);
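The workspace convention above implies a two-pass call pattern; a rough sketch follows, assuming `input`, `output` and `dbias` were prepared as documented and that the wrapper's `shape()` and `dtype()` reflect what the query pass filled in (the helper and the element-width table are illustrative):

// Sketch only: query-then-run workspace pattern for nvte_cast_transpose_dbias.
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>
#include <vector>

// Illustrative element widths for the dtypes a workspace is likely to report.
static size_t dtype_bytes(transformer_engine::DType t) {
  using transformer_engine::DType;
  switch (t) {
    case DType::kByte:
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
    case DType::kFloat8E8M0:
      return 1;
    case DType::kFloat16:
    case DType::kBFloat16:
      return 2;
    default:
      return 4;  // kFloat32 and friends; widen if other types are reported
  }
}

void example_cast_transpose_dbias(const transformer_engine::TensorWrapper &input,
                                  transformer_engine::TensorWrapper &output,
                                  transformer_engine::TensorWrapper &dbias, cudaStream_t stream) {
  using namespace transformer_engine;
  // Pass 1: an empty workspace makes the call only report the required shape/dtype.
  TensorWrapper workspace;
  nvte_cast_transpose_dbias(input.data(), output.data(), dbias.data(), workspace.data(), stream);

  // Pass 2: allocate the requested buffer, attach it, and run the actual kernel.
  const NVTEShape ws_shape = workspace.shape();
  size_t elems = 1;
  for (size_t i = 0; i < ws_shape.ndim; ++i) elems *= ws_shape.data[i];
  void *ws_dev = nullptr;
  cudaMalloc(&ws_dev, elems * dtype_bytes(workspace.dtype()));
  workspace.set_rowwise_data(ws_dev, workspace.dtype(),
                             std::vector<size_t>(ws_shape.data, ws_shape.data + ws_shape.ndim));
  nvte_cast_transpose_dbias(input.data(), output.data(), dbias.data(), workspace.data(), stream);
  cudaFree(ws_dev);
}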
/*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension.
*
......@@ -81,103 +80,243 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Cast and transpose multiple tensors.
*
* This function casts each input tensor and produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
*
* \param[in] num_tensors Number of tensors.
* \param[in] input_list List of 2D input tensors.
* \param[in,out] cast_output_list List of casted tensors. Dimensions
* match tensors in input_list.
* \param[in,out] transposed_output_list List of casted and transposed
* tensors. Dimensions are transpose
* of tensors in input_list.
* \param[in,out] output_list List of casted tensors. Dimensions
* of their rowwise data members match
* tensors in input_list. Dimensions of
* their columnwise data members are
* transposed.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
NVTETensor* cast_output_list, NVTETensor* transposed_output_list,
cudaStream_t stream);
NVTETensor* output_list, cudaStream_t stream);
/*! \brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally,
* reduce the result of the SiLU backward along the first dimension.
/*! \brief Compute backward of GeLU operation on the input, then cast and transpose.
* Additionally, reduce the result of the GeLU backward along the first dimension.
*
* This function produces 3 results:
* - `cast_output` is equal to `cast(dact(input))`
* - `transposed_output` is equal to `transpose(cast(dact(input)))`
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
* - `dbias` is equal to `reduce(dact(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] act_input Tensor used as input to the forward of SiLU operation.
* \param[in] act_input Tensor used as input to the forward activation operation.
* Shape [N, H].
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[out] dbias Result of the reduction of the dSiLU(input) along the
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H].
* Shape of columnwise data: [H, N].
* \param[out] dbias Result of the reduction of the dact(input) along the
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
*/
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute backward of SiLU operation on the input, then cast and transpose.
* Additionally, reduce the result of the SiLU backward along the first dimension.
*
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
* - `dbias` is equal to `reduce(dact(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] act_input Tensor used as input to the forward activation operation.
* Shape [N, H].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H].
* Shape of columnwise data: [H, N].
* \param[out] dbias Result of the reduction of the dact(input) along the
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute backward of ReLU operation on the input, then cast and transpose.
* Additionally, reduce the result of the ReLU backward along the first dimension.
*
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
* - `dbias` is equal to `reduce(dact(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] act_input Tensor used as input to the forward activation operation.
* Shape [N, H].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H].
* Shape of columnwise data: [H, N].
* \param[out] dbias Result of the reduction of the dact(input) along the
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute backward of the Quick GeLU operation on the input, then cast and transpose.
* Additionally, reduce the result of the Quick GeLU backward along the first dimension.
*
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
* - `dbias` is equal to `reduce(dact(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] act_input Tensor used as input to the forward activation operation.
* Shape [N, H].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H].
* Shape of columnwise data: [H, N].
* \param[out] dbias Result of the reduction of the dact(input) along the
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute backward of the Squared ReLU operation on the input, then cast and transpose.
* Additionally, reduce the result of the Squared ReLU backward along the first dimension.
*
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
* - `dbias` is equal to `reduce(dact(input), axis=0)`
*
* Calling this function with workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] act_input Tensor used as input to the forward activation operation.
* Shape [N, H].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H].
* Shape of columnwise data: [H, N].
* \param[out] dbias Result of the reduction of the dact(input) along the
* first dimension. Shape: [H].
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output.
/*! \brief Computes the backward of the gated GeLU activation of the input, additionally casts
* and transposes the output.
*
* This function produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of GeGLU operation.
* \param[in] gated_act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] cast_output Result of the cast. Shape: [N, H * 2].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H * 2].
* Shape of columnwise data: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU
*/
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the backward of the gated Swish activation of the input,
* additionally casts and transposes the output.
*
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H * 2].
* Shape of columnwise data: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the backward of the gated ReLU activation of the input,
* additionally casts and transposes the output.
*
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H * 2].
* Shape of columnwise data: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the backward of the gated Quick GeLU activation of the input,
* additionally casts and transposes the output.
*
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H * 2].
* Shape of columnwise data: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
NVTETensor output, cudaStream_t stream);
/*! \brief Computes the backward of the gated Squared ReLU activation of the input,
* additionally casts and transposes the output.
*
* This function produces 2 results:
* - rowwise data of `output` is equal to `cast(dact(input))`
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
* Shape of rowwise data: [N, H * 2].
* Shape of columnwise data: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream);
NVTETensor output, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
......@@ -15,6 +15,7 @@
#include <numeric>
#include "transformer_engine/normalization.h"
#include "transformer_engine/transformer_engine.h"
/*
......@@ -38,13 +39,21 @@ Compute always in FP32
namespace transformer_engine {
namespace normalization {
TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype,
DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size,
bool zero_centered_gamma, bool is_tuned) {
cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
: cudnn_frontend::NormFwdPhase_t::INFERENCE;
}
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
bool is_tuned, NVTEScalingMode mode, bool training) {
// TODO: Add scaling_mode to general_key if needed
uint64_t general_key = static_cast<uint32_t>(itype) | (static_cast<uint32_t>(otype) << 3) |
(static_cast<uint32_t>(ctype) << 6) | (static_cast<uint32_t>(wtype) << 9) |
(uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 |
(uint32_t(zero_centered_gamma) << 16);
(uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) |
(uint32_t(mode) << 19) | (uint32_t(training) << 22);
return std::make_tuple(general_key, batch_size, hidden_size, is_tuned);
}
......@@ -64,8 +73,8 @@ TeNormalizationPlan<KernelParamsType>::TeNormalizationPlan(
kernel_params.fp8_out = is_fp8_dtype(otype);
}
// TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those
auto key =
get_key(NormType, NormStage, wtype, itype, otype, ctype, 0, hidden_size, false, is_tuned);
auto key = get_key(NVTE_Norm_Backend::Te, NormType, NormStage, wtype, itype, otype, ctype, 0,
hidden_size, false, is_tuned);
_kernel = KernelRegistry::getKernel(key);
this->_build();
......@@ -179,13 +188,25 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
DType wtype, DType itype, DType otype, DType ctype,
const size_t batch_size, const size_t hidden_size,
const size_t sm_count,
const bool zero_centered_gamma)
: _fp8_out(is_fp8_dtype(otype)), _zero_centered(zero_centered_gamma) {
const bool zero_centered_gamma,
const NVTEScalingMode mode, bool training)
: _fp8_out(is_fp8_dtype(otype)),
_zero_centered(zero_centered_gamma),
_training(training),
_norm_stage(NormStage),
_norm_type(NormType) {
static_assert(CUDNN_FRONTEND_VERSION >= 10601,
"CUDNN_FRONTEND_VERSION should be at least 1.6.1!");
namespace fe = cudnn_frontend;
if (is_tensor_scaling(mode)) {
_ndim_scale_block = 0;
} else {
NVTE_CHECK(mode == NVTE_MXFP8_1D_SCALING, "Unsupported scaling mode.");
_ndim_scale_block = 1;
}
_scalar_dptr = std::make_unique<char[]>(typeToSize(wtype));
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
......@@ -213,7 +234,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
.set_dim({1, hidden_dim, 1, 1})
.set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
.set_data_type(get_cudnn_fe_dtype(wtype)));
if (zero_centered_gamma) {
if (_zero_centered) {
_scalar_offset = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("one")
.set_dim({1, 1, 1, 1})
......@@ -230,41 +251,42 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
}
// Create graph computation nodes
if (NormStage == NVTE_Norm_Stage::Forward) {
if (_norm_stage == NVTE_Norm_Stage::Forward) {
_eps = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("epsilon")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ctype))
.set_is_pass_by_value(true));
if (NormType == NVTE_Norm_Type::LayerNorm) {
if (_norm_type == NVTE_Norm_Type::LayerNorm) {
_beta = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({1, hidden_dim, 1, 1})
.set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
.set_data_type(get_cudnn_fe_dtype(wtype)));
auto norm_options = fe::graph::Layernorm_attributes()
.set_forward_phase(fe::NormFwdPhase_t::TRAINING)
.set_forward_phase(get_cudnn_forward_phase(_training))
.set_epsilon(_eps)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
auto ret = _graph.layernorm(_x, _gamma, _beta, norm_options);
std::tie(_z, _mean, _rsigma) = std::make_tuple(ret[0], ret[1], ret[2]);
_mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype));
} else if (NormType == NVTE_Norm_Type::RMSNorm) {
if (_training) _mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype));
} else {
auto norm_options = fe::graph::Rmsnorm_attributes()
.set_forward_phase(fe::NormFwdPhase_t::TRAINING)
.set_forward_phase(get_cudnn_forward_phase(_training))
.set_epsilon(_eps)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
auto ret = _graph.rmsnorm(_x, _gamma, norm_options);
std::tie(_z, _rsigma) = std::make_tuple(ret[0], ret[1]);
}
_rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype));
if (_training) _rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype));
const auto ZDtype = _fp8_out ? ctype : otype;
_z->set_output(!_fp8_out).set_data_type(get_cudnn_fe_dtype(ZDtype));
if (_fp8_out) {
if (_ndim_scale_block == 0) { // tensor_scaling
// create a scale node
_z_scale = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("z_scale")
......@@ -283,6 +305,43 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
.set_mode(fe::ReductionMode_t::AMAX)
.set_compute_data_type(get_cudnn_fe_dtype(ctype)));
_amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1});
_one_for_div = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("one_for_div")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ctype))
.set_is_pass_by_value(true));
auto div_options = fe::graph::Pointwise_attributes()
.set_mode(fe::PointwiseMode_t::DIV)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
_z_scale_inv = _graph.pointwise(_one_for_div, _z_scale, div_options);
_z_scale_inv->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype));
} else if (_ndim_scale_block == 1) { // 1d block scaling
auto z_2d = _graph.reshape(_z, fe::graph::Reshape_attributes());
z_2d->set_dim({batch_dim, hidden_dim});
auto mx_quantize_row_opts = fe::graph::Block_scale_quantize_attributes()
.set_block_size(32)
.set_axis(1)
.set_transpose(false);
auto bs_row_ret = _graph.block_scale_quantize(z_2d, mx_quantize_row_opts);
std::tie(_z_mx_row, _sf_row) = std::make_tuple(bs_row_ret[0], bs_row_ret[1]);
_z_mx_row->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
_sf_row->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0); //TODO
if (_training) {
auto mx_quantize_col_opts = fe::graph::Block_scale_quantize_attributes()
.set_block_size(32)
.set_axis(0)
.set_transpose(false);
auto bs_col_ret = _graph.block_scale_quantize(z_2d, mx_quantize_col_opts);
std::tie(_z_mx_col, _sf_col) = std::make_tuple(bs_col_ret[0], bs_col_ret[1]);
_z_mx_col->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
_sf_col->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0);
}
} else {
NVTE_ERROR("Unsupported scaling mode.");
}
}
} else {
_dz = _graph.tensor(fe::graph::Tensor_attributes()
......@@ -299,7 +358,7 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
.set_dim({batch_dim, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ctype)));
if (NormType == NVTE_Norm_Type::LayerNorm) {
if (_norm_type == NVTE_Norm_Type::LayerNorm) {
auto norm_options = fe::graph::Layernorm_backward_attributes()
.set_saved_mean_and_inv_variance(_mean, _rsigma)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
......@@ -341,10 +400,14 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
void* workspace_dptr, cudaStream_t stream) {
// Binding data pointers to graph tensors
_variant_pack = {{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_eps, eps_dptr}};
_variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}};
// layernorm should have valid mean_dptr and beta_dptr
if (mean_dptr && beta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_beta, beta_dptr}});
if (_training) _variant_pack.insert({{_rsigma, rsigma_dptr}});
if (_norm_type == NVTE_Norm_Type::LayerNorm) {
_variant_pack.insert({{_beta, beta_dptr}});
if (_training) _variant_pack.insert({{_mean, mean_dptr}});
}
if (_zero_centered)
_variant_pack.insert(
......@@ -352,16 +415,24 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
else
_variant_pack.insert({{_gamma, gamma_dptr}});
if (_fp8_out)
if (_fp8_out && _ndim_scale_block == 0) {
_variant_pack.insert({{_one_for_div, reinterpret_cast<void*>(_one_dptr.get())},
{_z_scale, z->scale.dptr},
{_z_scale_inv, z->scale_inv.dptr},
{_amax, z->amax.dptr},
{_z_fp8, z->data.dptr}});
} else if (_fp8_out && _ndim_scale_block == 1) {
_variant_pack.insert({{_z_mx_row, z->data.dptr}, {_sf_row, z->scale_inv.dptr}});
if (_training)
_variant_pack.insert(
{{_z_scale, z->scale.dptr}, {_amax, z->amax.dptr}, {_z_fp8, z->data.dptr}});
else
{{_z_mx_col, z->columnwise_data.dptr}, {_sf_col, z->columnwise_scale_inv.dptr}});
} else {
_variant_pack.insert({{_z, z->data.dptr}});
}
// Execute the computation
NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
if (_fp8_out) update_tensor_scale_inv(z, stream);
}
void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
......@@ -389,11 +460,12 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned) {
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
const NVTEScalingMode mode, const bool training) {
const DType ctype = DType::kFloat32;
bool is_tuned = is_aligned && (batch_size % 4 == 0);
auto key = get_key(NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size,
zero_centered_gamma, is_tuned);
auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size,
hidden_size, zero_centered_gamma, is_tuned, mode, training);
auto it = normalizationPlanMap.find(key);
if (it != normalizationPlanMap.end()) {
......@@ -404,7 +476,7 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
if (NormBackend == NVTE_Norm_Backend::Cudnn) {
plan = std::make_unique<CudnnNormalizationPlan>(NormType, NormStage, wtype, itype, otype, ctype,
batch_size, hidden_size, sm_count,
zero_centered_gamma);
zero_centered_gamma, mode, training);
} else if (NormStage == NVTE_Norm_Stage::Forward) {
plan = std::make_unique<TeNormalizationPlan<ForwardKernelParams>>(
NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
......
......@@ -154,9 +154,12 @@ struct TupleHash {
}
};
TupleKeyType get_key(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype,
DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size,
bool zero_centered_gamma, bool is_tuned);
// Note: the default mode here should match the default mode of QTensor
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
bool training = true);
template <typename KernelParamsType>
class TeNormalizationRegistry {
......@@ -257,7 +260,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
DType itype, DType otype, DType ctype, const size_t batch_size,
const size_t hidden_size, const size_t sm_count,
const bool zero_centered_gamma);
const bool zero_centered_gamma, const NVTEScalingMode mode,
const bool training);
std::vector<size_t> getWorkspaceShape() const override;
......@@ -273,10 +277,17 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
void _build() override;
const bool _zero_centered, _fp8_out;
int _ndim_scale_block;
const NVTE_Norm_Stage _norm_stage;
const NVTE_Norm_Type _norm_type;
std::unique_ptr<char[]> _scalar_dptr;
std::unique_ptr<float> _one_dptr = std::make_unique<float>(1.0f);
// FWD
std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta,
_eps, _mean, _rsigma, _z, _z_scale, _amax, _z_fp8;
_eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8;
// MX FWD
std::shared_ptr<fe::graph::Tensor_attributes> _z_mx_row, _z_mx_col, _sf_row, _sf_col;
const bool _training;
// BWD
std::shared_ptr<fe::graph::Tensor_attributes> _dz, _dx, _dgamma, _dbeta;
......@@ -292,12 +303,11 @@ class NormalizationPlanRegistry {
return instance;
}
NormalizationPlanBase* getNormalizationPlan(NVTE_Norm_Backend NormBackend,
NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
DType wtype, DType itype, DType otype,
const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma,
const bool is_aligned);
NormalizationPlanBase* getNormalizationPlan(
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true);
private:
NormalizationPlanRegistry() {}
......@@ -356,15 +366,12 @@ struct TypeToDType<byte> {
static int \
register_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE = \
TeNormalizationRegistry<NORM_STAGE##KernelParams>::registerFunction( \
(get_key(NVTE_Norm_Type::NORM_TYPE, NVTE_Norm_Stage::NORM_STAGE, \
(TypeToDType<WTYPE>::value), (TypeToDType<ITYPE>::value), \
(TypeToDType<OTYPE>::value), (TypeToDType<CTYPE>::value), 0, HIDDEN_SIZE, \
0, IS_TUNED(LAUNCH_TYPE))), \
(get_key(NVTE_Norm_Backend::Te, NVTE_Norm_Type::NORM_TYPE, \
NVTE_Norm_Stage::NORM_STAGE, (TypeToDType<WTYPE>::value), \
(TypeToDType<ITYPE>::value), (TypeToDType<OTYPE>::value), \
(TypeToDType<CTYPE>::value), 0, HIDDEN_SIZE, 0, IS_TUNED(LAUNCH_TYPE))), \
FUNC_NAME)
// For FP8 only
void ComputeScaleInv(void* scale, void* scale_inv);
// Alignment check
template <size_t Alignment = 16, typename... Args>
bool is_ptr_aligned(const Args*... ptrs) {
......@@ -375,7 +382,6 @@ bool use_cudnn_norm_fwd();
bool use_cudnn_norm_bwd();
} // namespace normalization
} // namespace transformer_engine
#endif
......@@ -5,6 +5,7 @@
************************************************************************/
#include <transformer_engine/normalization.h>
#include <transformer_engine/transpose.h>
#include <cstdint>
#include <cstdlib>
......@@ -25,6 +26,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_block_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(gamma.data.shape == beta.data.shape);
NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]);
......@@ -51,7 +57,9 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (use_cudnn_norm_fwd()) {
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
if (cudnn_backend) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
} else {
......@@ -59,6 +67,10 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr,
mu->data.dptr, rsigma->data.dptr);
}
bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Forward,
gamma.data.dtype, // wtype
......@@ -66,18 +78,31 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
z->data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
} else {
}
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
NVTE_CHECK(
!is_block_scaling(z->scaling_mode) || (!training || z->columnwise_scale_inv.dptr != nullptr),
"Columnwise scale_inv must be allocated for NormFwdTraining!");
plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr,
reinterpret_cast<void*>(const_cast<float*>(&epsilon)), rsigma->data.dptr,
workspace->data.dptr, stream);
// Compute FP8 transpose if required
if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) {
Tensor transpose_data;
transpose_data.data = z->columnwise_data;
transpose_data.scaling_mode = z->scaling_mode;
nvte_transpose(reinterpret_cast<NVTETensor>(z), reinterpret_cast<NVTETensor>(&transpose_data),
stream);
}
return;
}
......
......@@ -13,6 +13,7 @@
#include "../../common.h"
#include "../common.h"
#include "transformer_engine/normalization.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
......@@ -21,6 +22,11 @@ using namespace normalization;
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_block_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]);
......@@ -39,17 +45,21 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
CheckOutputTensor(*rsigma, "rsigma");
}
Tensor empty;
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (use_cudnn_norm_fwd()) {
bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode);
bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
if (cudnn_backend) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr);
}
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward,
gamma.data.dtype, // wtype
......@@ -57,17 +67,29 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
z->data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
} else {
}
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr, nullptr,
NVTE_CHECK(
!is_block_scaling(z->scaling_mode) || (!training || z->columnwise_scale_inv.dptr != nullptr),
"Columnwise scale_inv must be allocated for NormFwdTraining!");
plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr /*beta*/, nullptr /*mu*/,
reinterpret_cast<void *>(const_cast<float *>(&epsilon)), rsigma->data.dptr,
workspace->data.dptr, stream);
// Compute FP8 transpose if required
if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) {
Tensor transpose_data;
transpose_data.data = z->columnwise_data;
transpose_data.scaling_mode = z->scaling_mode;
nvte_transpose(reinterpret_cast<NVTETensor>(z), reinterpret_cast<NVTETensor>(&transpose_data),
stream);
}
return;
......@@ -101,8 +123,6 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
CheckOutputTensor(*dgamma, "dgamma");
}
Tensor empty;
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
if (use_cudnn_norm_bwd()) {
......@@ -128,8 +148,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
return;
} else {
NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
plan->execute(x.data.dptr, gamma.data.dptr, nullptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, nullptr, dgamma->data.dptr, workspace->data.dptr, stream);
plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr, stream);
}
return;
}
......
......@@ -39,19 +39,22 @@ class Format(Enum):
HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
class _OverrideLinearPrecision(NamedTuple):
class Recipe:
"""
Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
GEMMs in higher precision when using FP8.
Base recipe class.
"""
fprop: bool = False
dgrad: bool = False
wgrad: bool = False
def mxfp8(self):
"""Whether the given recipe is MXFP8 block scaling."""
return isinstance(self, MXFP8BlockScaling)
def delayed(self):
"""Whether the given recipe is delayed scaling."""
return isinstance(self, DelayedScaling)
@dataclass()
class DelayedScaling:
class DelayedScaling(Recipe):
"""
Use the delayed scaling factor strategy. Use scale factor from previous
iteration and record amax history of `amax_history_len` steps.
......@@ -92,9 +95,6 @@ class DelayedScaling:
recipe: DelayedScaling) -> Tensor
where `Tensor` is a framework tensor type.
override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False)
Whether or not to execute the `fprop`, `dgrad`, and `wgrad`
GEMMs (respectively) in higher precision when using FP8.
reduce_amax: bool, default = `True`
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
......@@ -137,7 +137,6 @@ class DelayedScaling:
fp8_format: Format = Format.HYBRID
amax_history_len: int = 1024
amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max"
override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision()
scaling_factor_compute_algo: Optional[Callable] = None
reduce_amax: bool = True
fp8_dpa: bool = False
......@@ -145,10 +144,6 @@ class DelayedScaling:
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert self.override_linear_precision in (
(False, False, False),
(False, False, True),
), "Only wgrad GEMM override is currently supported."
if self.interval >= 0:
warnings.warn(
"`interval` argument is deprecated and unused. "
......@@ -161,7 +156,32 @@ class DelayedScaling:
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"wgrad_override={self.override_linear_precision.wgrad}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
@dataclass()
class MXFP8BlockScaling(Recipe):
"""
Use the MXFP8 block scaling factor strategy.
Parameters
----------
margin : int, default = 0
Margin for the scaling factor computation.
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
"""
margin: int = 0
fp8_format: Format = Format.E4M3
fp8_dpa: bool = False
fp8_mha: bool = False
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
def __repr__(self) -> str:
return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]},"
......@@ -46,7 +46,6 @@ struct AmaxParam {
int num_scale = 0;
float* amax_history = nullptr;
float* scale = nullptr;
float* scale_inv = nullptr;
};
// dummy struct for kernel_bulk's other params
......@@ -83,10 +82,9 @@ constexpr size_t bsize = 256;
* Grid dims: num_scales x 1 x 1
*/
__global__ void __launch_bounds__(bsize)
kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr,
const unsigned char* scale_inv_mask_ptr, float* updated_amax_history_ptr,
float* updated_scale_ptr, float* updated_scale_inv_ptr, size_t amax_history_length,
size_t amax_history_stride, AmaxComputeAlgo amax_compute_algo, float scaled_max) {
kernel(const float* amax_history_ptr, const float* scale_ptr, float* updated_amax_history_ptr,
float* updated_scale_ptr, size_t amax_history_length, size_t amax_history_stride,
AmaxComputeAlgo amax_compute_algo, float scaled_max) {
const size_t tid = threadIdx.x;
const size_t bid = blockIdx.x;
......@@ -135,7 +133,7 @@ __global__ void __launch_bounds__(bsize)
}
}
// Update scale and scale inverse
// Update scale
if (tid == 0) {
// Update scale
float scale;
......@@ -152,15 +150,6 @@ __global__ void __launch_bounds__(bsize)
scale = std::numeric_limits<float>::max();
}
updated_scale_ptr[bid] = scale;
// Update scale inverse
float scale_inv;
if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) {
scale_inv = 1 / scale;
} else {
scale_inv = scale_inv_ptr[bid];
}
updated_scale_inv_ptr[bid] = scale_inv;
}
}
......@@ -232,7 +221,7 @@ __global__ void __launch_bounds__(bsize)
}
}
// Update scale and scale inverse
// Update scale
if (tid == 0) {
// Computing the scaling factor requires consideration of the following scenarios:
// 1. amax == 0:
......@@ -259,7 +248,6 @@ __global__ void __launch_bounds__(bsize)
scale = std::numeric_limits<float>::max();
}
p.param[bid].scale[count] = scale;
p.param[bid].scale_inv[count] = 1 / scale;
}
}
}
......@@ -268,23 +256,12 @@ __global__ void __launch_bounds__(bsize)
} // namespace
void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, const Tensor& scale_inv,
const Tensor& scale_inv_mask, Tensor* updated_amax_history_,
Tensor* updated_scale_, Tensor* updated_scale_inv_,
void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale,
Tensor* updated_amax_history_, Tensor* updated_scale_,
const std::string& amax_compute_algo, DType fp8_dtype, float margin,
cudaStream_t stream) {
auto& updated_amax_history = *updated_amax_history_;
auto& updated_scale = *updated_scale_;
auto& updated_scale_inv = *updated_scale_inv_;
// Number of elements in tensor
auto numel = [](const Tensor& tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
// Check tensors
NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(),
......@@ -293,18 +270,9 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons
const size_t num_scales = amax_history.data.shape[1];
NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ",
dtype_name(amax_history.data.dtype), ".");
NVTE_CHECK(numel(scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ",
numel(scale), ".");
NVTE_CHECK(scale.numel() == num_scales, "Expected ", num_scales, " elements, ", "but found ",
scale.numel(), ".");
NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), ".");
if (scale_inv_mask.data.dptr != nullptr) {
NVTE_CHECK(numel(scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ",
numel(scale_inv), ".");
NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32);
NVTE_CHECK(numel(scale_inv_mask) == num_scales, "Expected ", num_scales, " elements, ",
"but found ", numel(scale_inv_mask), ".");
NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ",
dtype_name(scale_inv_mask.data.dtype), ".");
}
NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ",
updated_amax_history.data.shape.size(), " dims.");
NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ",
......@@ -313,14 +281,10 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons
"but found ", updated_amax_history.data.shape[1]);
NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ",
dtype_name(updated_amax_history.data.dtype), ".");
NVTE_CHECK(numel(updated_scale) == num_scales, "Expected ", num_scales, " elements, ",
"but found ", numel(updated_scale), ".");
NVTE_CHECK(updated_scale.numel() == num_scales, "Expected ", num_scales, " elements, ",
"but found ", updated_scale.numel(), ".");
NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ",
dtype_name(updated_scale.data.dtype), ".");
NVTE_CHECK(numel(updated_scale_inv) == num_scales, "Expected ", num_scales, " elements, ",
"but found ", numel(updated_scale_inv), ".");
NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, "Got ",
dtype_name(updated_scale_inv.data.dtype), ".");
// amax value to use for updating scaling factor
AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
......@@ -340,11 +304,8 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons
const size_t grid_size = num_scales;
amax_and_scale_update_impl::kernel<<<grid_size, block_size, 0, stream>>>(
static_cast<const float*>(amax_history.data.dptr), static_cast<const float*>(scale.data.dptr),
static_cast<const float*>(scale_inv.data.dptr),
static_cast<const unsigned char*>(scale_inv_mask.data.dptr),
static_cast<float*>(updated_amax_history.data.dptr),
static_cast<float*>(updated_scale.data.dptr),
static_cast<float*>(updated_scale_inv.data.dptr), amax_history_length, num_scales,
static_cast<float*>(updated_scale.data.dptr), amax_history_length, num_scales,
amax_compute_algo_, scaled_max);
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -352,7 +313,6 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons
void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
std::vector<Tensor*> amax_histories,
std::vector<Tensor*> scales,
std::vector<Tensor*> scale_invs,
const std::string& amax_compute_algo, DType fp8_dtype,
float margin, cudaStream_t stream) {
using namespace transformer_engine;
......@@ -370,15 +330,6 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
// Expected maximum value after scale is applied
const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);
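// For example: with E4M3 the maximum representable value is 448, so margin = 0
// gives scaled_max = 448 and margin = 2 gives 448 * 2^-2 = 112.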
// Number of elements in tensor
auto numel = [](const Tensor* tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor->data.shape) {
acc *= dim;
}
return acc;
};
// Number of tensors in the bulk
const size_t num_tensors = amax_histories.size();
size_t num_remaining_tensors = num_tensors;
......@@ -404,22 +355,21 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
dtype_name(amax_histories[i]->data.dtype), ".");
NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ",
amax_histories[i]->data.shape.size(), " dims");
NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, "Expected ",
NVTE_CHECK(amax_histories[i]->numel() == amax_history_length * num_scale, "Expected ",
amax_history_length * num_scale, " elements, ", "but found ",
numel(amax_histories[i]), ".");
amax_histories[i]->numel(), ".");
NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ",
dtype_name(scales[i]->data.dtype), ".");
NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(),
" dims");
NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ",
numel(scales[i]), ".");
NVTE_CHECK(scales[i]->numel() == num_scale, "Expected ", num_scale, " elements, ", "Found ",
scales[i]->numel(), ".");
// amax parameters
kernel_num_scales += num_scale;
p.param[pi].num_scale = num_scale;
p.param[pi].amax_history = static_cast<float*>(amax_histories[i]->data.dptr);
p.param[pi].scale = static_cast<float*>(scales[i]->data.dptr);
p.param[pi].scale_inv = static_cast<float*>(scale_invs[i]->data.dptr);
}
// Launch CUDA kernel
......@@ -441,34 +391,30 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
} // namespace transformer_engine
void nvte_delayed_scaling_recipe_amax_and_scale_update(
const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv,
const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale,
NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history,
NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
using namespace transformer_engine;
delayed_scaling_recipe::amax_and_scale_update(
*reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale),
*reinterpret_cast<const Tensor*>(scale_inv), *reinterpret_cast<const Tensor*>(scale_inv_mask),
reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale),
reinterpret_cast<Tensor*>(updated_scale_inv), amax_compute_algo,
static_cast<DType>(fp8_dtype), margin, stream);
amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs,
const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) {
std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
float margin, cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction);
using namespace transformer_engine;
size_t num_tensors = amax_histories.size();
std::vector<Tensor*> t_amax_histories, t_scales, t_scale_invs;
std::vector<Tensor*> t_amax_histories, t_scales;
for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories.push_back(reinterpret_cast<Tensor*>(amax_histories[i]));
t_scales.push_back(reinterpret_cast<Tensor*>(scales[i]));
t_scale_invs.push_back(reinterpret_cast<Tensor*>(scale_invs[i]));
}
delayed_scaling_recipe::amax_and_scale_update_after_reduction(
*reinterpret_cast<const Tensor*>(amax_reduction_buffer), t_amax_histories, t_scales,
t_scale_invs, amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>
#include <cassert>
#include <numeric>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "transformer_engine/transformer_engine.h"
namespace {
constexpr int TB_DIM = 32;
constexpr int NEW_SF_TILE_DIM_K = 16;
constexpr int N_SF_PER_TD_PER_TILE = 4;
// output is in ~K-major interleaved blocks
constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
constexpr int NEW_SF_TILE_DIM_M_I32 = 32;
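// Derived from the constants above: each int32 packs 4 one-byte scale factors,
// so a swizzled output tile is NEW_SF_TILE_DIM_M_I32 x NEW_SF_TILE_DIM_K_I32 =
// 32 x 4 words, i.e. a 32 x 16 byte block holding the same 512 factors as a
// 128 x 4 input tile.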
template <typename LType>
__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
// inp, 4-byte chunks [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
// out, swapping byte to form new 4-byte chunks [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15]
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD;
int32_t new_regs[kVectorSize];
int32_t* regs = reinterpret_cast<int32_t*>(regs_vec);
#pragma unroll
for (int i = 0; i < N_TILE_PER_TD; i++) {
#pragma unroll
for (int j = 0; j < N_SF_PER_TD_PER_TILE; j++) {
new_regs[i * N_SF_PER_TD_PER_TILE + j] =
(((regs[i + 0 * N_TILE_PER_TD] >> 8 * j) & 0xFF)) |
(((regs[i + 1 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 8) |
(((regs[i + 2 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 16) |
(((regs[i + 3 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 24);
}
}
#pragma unroll
for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i];
}
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M,
const int K) {
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
// input is in M-major
constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;
const int M_i32 = M / 4;
const int K_i32 = K;
int m_tiles_in_tb = N_TILE_PER_TD;
int k_tiles_in_tb = TB_DIM;
if (blockIdx.x == gridDim.x - 1) {
k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
}
if (blockIdx.y == gridDim.y - 1) {
m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
}
const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
int32_t* output_i32[N_TILE_PER_TD];
#pragma unroll
for (int i = 0; i < m_tiles_in_tb; i++) {
output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
(blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
}
extern __shared__ int slm[];
// load, global -> regs
LType regs_vec[N_SF_PER_TD_PER_TILE];
if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 &&
threadIdx.y < k_tiles_in_tb) {
#pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD));
}
// local shuffle
regs_shuffle_with_bit_shifts(regs_vec);
// store, regs -> shared
int tM = threadIdx.x * N_SF_PER_TD;
int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
for (int i = 0; i < N_SF_PER_TD; i++) {
/* TODO rotate_i */
slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
reinterpret_cast<int*>(regs_vec)[i];
}
}
__syncthreads();
// store, shared -> global
int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
#pragma unroll
for (int i = 0; i < m_tiles_in_tb; i++) {
__align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
__align__(16) int4* slm_v4i =
reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
#pragma unroll
for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4;
j += blockDim.x * blockDim.y) {
output_v4i[j] = slm_v4i[j];
}
}
}
template <typename LType>
__device__ inline void regs_shuffle(LType* regs_vec) {
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
if constexpr (N_TILE_PER_TD == 1) return;
constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD;
int32_t tmp[kVectorSize];
int32_t* ptr = reinterpret_cast<int32_t*>(regs_vec);
#pragma unroll
for (int i = 0; i < kVectorSize; i++)
tmp[i % N_TILE_PER_TD * N_SF_PER_TD_PER_TILE + i / N_TILE_PER_TD] = ptr[i];
#pragma unroll
for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i];
}
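// Illustration (for LType = int2, i.e. N_TILE_PER_TD = 2): the 8 int32 chunks
// [0,1,2,3,4,5,6,7] are reordered to [0,2,4,6, 1,3,5,7], grouping the chunks
// that belong to the same tile so that each int4 later written to shared memory
// by the caller comes from a single tile.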
template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M,
const int K) {
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
// input is in K-major
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M;
int n_tiles_in_tb = N_TILES_IN_TB;
const int K_i32 = K / 4;
if (blockIdx.x == gridDim.x - 1) {
n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
}
const int* input_i32 = reinterpret_cast<const int*>(input) +
blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
extern __shared__ int4 slm_v4i[];
// load, global -> regs
LType regs_vec[N_SF_PER_TD_PER_TILE];
if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
#pragma unroll
for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD));
}
// shuffle regs
regs_shuffle<LType>(regs_vec);
// store, regs -> shared
#pragma unroll
for (int i = 0; i < N_TILE_PER_TD; i++) {
/* TODO rotate i */
slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] =
reinterpret_cast<int4*>(regs_vec)[i];
}
}
__syncthreads();
// store, shared -> global
int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
__align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32);
#pragma unroll
for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) {
output_v4i[i] = slm_v4i[i];
}
}
} // namespace
namespace transformer_engine {
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + ".");
}
// Do nothing if tensor is empty
if (input->data.numel() == 0) {
return;
}
CheckInputTensor(*input, "scaling_factor_input");
CheckInputTensor(*output, "scaling_factor_output");
auto& scaling_mode = input->scaling_mode;
// 1D block scaling, row-wise or column-wise
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
const int m =
input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1];
const int k =
input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0];
constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output->has_data()) {
NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(),
output->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(),
output->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
}
int num_tiles_m = m / SF_TILE_DIM_M;
int num_tiles_k = k / SF_TILE_DIM_K;
dim3 block_size(TB_DIM, TB_DIM);
if (input->has_data()) {
int vec_load_size = (num_tiles_k - 1) % 4 + 1;
/* there is no int3 and misaligned if using int4/int2 */
if (vec_load_size == 3) vec_load_size = 1;
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
switch (vec_load_size) {
case 4:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
break;
case 2:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
break;
case 1:
cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
output->scale_inv.dptr, m, k);
break;
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
}
if (input->has_columnwise_data()) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1;
if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
switch (vec_load_size) {
case 4:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
break;
case 2:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
break;
case 1:
cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k);
break;
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
}
// 2D block scaling
} else {
NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("CUDA Error: %s\n", cudaGetErrorString(err));
exit(-1);
}
}
} // namespace transformer_engine
/*
* WIP (Phuong):
* - Opt for bank conflicts
* - Adding swizzle for 2d-block scaling.
*/
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_scaling_factors);
using namespace transformer_engine;
swizzle_scaling_factors(reinterpret_cast<const Tensor*>(input), reinterpret_cast<Tensor*>(output),
stream);
}
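// Usage note (an assumption, not stated in this change): for MXFP8 tensors the
// scale-inverse factors are produced in plain row- or column-major order, while
// the block-scaled GEMM consumes them in the swizzled 128x4-tile layout, so this
// entry point is expected to be called on the scaling factors before the GEMM
// that uses them.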
......@@ -6,71 +6,196 @@
#include <transformer_engine/transformer_engine.h>
#include <iostream>
#include "common.h"
namespace transformer_engine {
size_t typeToSize(const transformer_engine::DType type) {
size_t typeToSize(const DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
return TypeInfo<T>::size;); // NOLINT(*)
}
bool is_fp8_dtype(const transformer_engine::DType t) {
return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2;
bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; }
std::string to_string(const DType type) {
switch (type) {
case DType::kByte:
return "Byte";
case DType::kBFloat16:
return "BFloat16";
case DType::kFloat16:
return "Float16";
case DType::kFloat32:
return "Float32";
case DType::kFloat8E4M3:
return "Float8E4M3";
case DType::kFloat8E5M2:
return "Float8E5M2";
case DType::kFloat8E8M0:
return "Float8E8M0";
case DType::kInt32:
return "Int32";
case DType::kInt64:
return "Int64";
default:
return concat_strings("Invalid type ", static_cast<int>(type));
}
}
std::string to_string(const NVTEScalingMode &mode) {
switch (mode) {
case NVTE_DELAYED_TENSOR_SCALING:
return "Delayed Tensor Scaling";
case NVTE_MXFP8_1D_SCALING:
return "MXFP8 1D Scaling";
case NVTE_INVALID_SCALING:
return "Invalid Scaling";
}
return "Invalid Scaling";
}
void CheckNoopTensor(const Tensor &t, const std::string &name) {
if (t.data.dptr != nullptr) {
NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
".");
NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,
" noop. Expected kFloat32.");
}
}
void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
if (is_tensor_scaling(t.scaling_mode)) {
// per-tensor scaling
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected (1), got ",
t.columnwise_scale_inv.shape, ")");
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
// Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul, 4ul};
size_t expected_x, expected_y, alignment;
if (t.has_data()) {
alignment = block_alignment[0];
expected_x =
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
expected_y =
DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(32)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ",
t.scale_inv.shape, ")");
}
if (t.has_columnwise_data()) {
alignment = block_alignment[1];
expected_x =
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0];
expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
t.columnwise_scale_inv.shape, ")");
}
}
}
}
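// Worked example (not part of this change): a 1000 x 1000 rowwise MXFP8 tensor
// has flat_first_dim = flat_last_dim = 1000, so the expected scale_inv shape is
// (DIVUP(1000, 1) padded to a multiple of 128, DIVUP(1000, 32) padded to a
// multiple of 4) = (1024, 32).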
void CheckInputTensor(const Tensor &t, const std::string &name) {
const DType type = t.data.dtype;
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
// FP8 input needs to have scale_inv
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale.");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor input ", name,
"_scale_inverse has invalid dtype "
"(expected Float32 or Byte, got ",
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor input ", name,
"_columnwise_scale_inverse has invalid dtype "
"(expected Float32 or Byte, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
}
NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!");
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");
CheckScaleTensorShape(t, name);
}
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
const DType type = t.data.dtype;
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
// FP8 output needs to have scale, amax and scale_inv
NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor.");
NVTE_CHECK(t.amax.dtype == DType::kFloat32);
NVTE_CHECK(t.amax.shape == std::vector<size_t>{1});
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale.");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale.");
NVTE_CHECK(t.scale.dtype == DType::kFloat32);
NVTE_CHECK(t.scale.shape == std::vector<size_t>{1});
// FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output ", name, " must have amax tensor");
NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
" (expected 1 entry, got shape=", t.amax.shape, ")");
}
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor output ", name,
"_scale_inverse has invalid dtype "
"(expected Float32 or Float8E8M0, got ",
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor output ", name,
"_columnwise_scale_inverse has invalid dtype "
"(expected Float32 or Float8E8M0, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + ".");
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + ".");
NVTE_CHECK(t.scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 output " + name + ".");
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
}
if (!allow_empty) {
NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!");
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
}
CheckScaleTensorShape(t, name);
}
} // namespace transformer_engine
NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax,
float *scale, float *scale_inv) {
NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
transformer_engine::Tensor *ret = new transformer_engine::Tensor;
ret->data.dptr = dptr;
ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
ret->data.dtype = static_cast<transformer_engine::DType>(dtype);
ret->amax.dptr = amax;
ret->scale.dptr = scale;
ret->scale_inv.dptr = scale_inv;
ret->scaling_mode = scaling_mode;
return ret;
}
......@@ -81,30 +206,65 @@ void nvte_destroy_tensor(NVTETensor tensor) {
}
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
if (tensor == nullptr) return kNVTEFloat32;
return static_cast<NVTEDType>(
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->data.dtype);
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype());
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) return {nullptr, 0};
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret;
// FP8 tensor keeps shape in rowwise data
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
return ret;
}
// Get shape based on what data is available
if (t.has_data()) {
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
return ret;
}
if (t.has_columnwise_data()) {
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
return ret;
}
// Tensor has no data
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
return ret;
}
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
if (tensor == nullptr) return {nullptr, 0};
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret;
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
return ret;
}
size_t nvte_tensor_ndim(const NVTETensor tensor) {
if (tensor == nullptr) return 0;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.shape.size();
}
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
if (tensor == nullptr) return 0;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(dim < t.data.shape.size(), "Invalid dimension index: ", dim);
return t.data.shape[dim];
}
size_t nvte_tensor_numel(const NVTETensor tensor) {
if (tensor == nullptr) return 0;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
size_t numel = 1;
for (auto size : t.data.shape) {
......@@ -114,16 +274,25 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
}
size_t nvte_tensor_element_size(const NVTETensor tensor) {
if (tensor == nullptr) return sizeof(float);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return transformer_engine::typeToSize(t.data.dtype);
}
void *nvte_tensor_data(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.dptr;
}
void *nvte_tensor_columnwise_data(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.columnwise_data.dptr;
}
float *nvte_tensor_amax(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
"Tensor's amax must have Float32 type!");
......@@ -131,6 +300,7 @@ float *nvte_tensor_amax(const NVTETensor tensor) {
}
float *nvte_tensor_scale(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
"Tensor's scale must have Float32 type!");
......@@ -138,12 +308,83 @@ float *nvte_tensor_scale(const NVTETensor tensor) {
}
float *nvte_tensor_scale_inv(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32,
"Tensor's inverse of scale must have Float32 type!");
return reinterpret_cast<float *>(t.scale_inv.dptr);
}
void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.columnwise_scale_inv.dptr;
}
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
if (tensor == nullptr) return {nullptr, 0};
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret;
ret.data = t.scale_inv.shape.data();
ret.ndim = t.scale_inv.shape.size();
return ret;
}
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
const NVTEBasicTensor *param) {
NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated.");
auto &t = *reinterpret_cast<transformer_engine::Tensor *>(*tensor);
switch (param_name) {
case kNVTERowwiseData:
t.data = *param;
break;
case kNVTEColumnwiseData:
t.columnwise_data = *param;
break;
case kNVTEScale:
t.scale = *param;
break;
case kNVTEAmax:
t.amax = *param;
break;
case kNVTERowwiseScaleInv:
t.scale_inv = *param;
break;
case kNVTEColumnwiseScaleInv:
t.columnwise_scale_inv = *param;
break;
default:
NVTE_ERROR("Unknown tensor parameter!");
}
}
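// Minimal usage sketch (hypothetical, not part of this change); `d_data` and the
// 128 x 256 shape are placeholders:
//
//   NVTETensor t = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);
//   size_t dims[2] = {128, 256};
//   NVTEBasicTensor data{d_data, kNVTEFloat32, {dims, 2}};
//   nvte_set_tensor_param(&t, kNVTERowwiseData, &data);
//   // ... use the tensor ...
//   nvte_destroy_tensor(t);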
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, {nullptr, 0}};
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
switch (param_name) {
case kNVTERowwiseData:
return t.data;
case kNVTEColumnwiseData:
return t.columnwise_data;
case kNVTEScale:
return t.scale;
case kNVTEAmax:
return t.amax;
case kNVTERowwiseScaleInv:
return t.scale_inv;
case kNVTEColumnwiseScaleInv:
return t.columnwise_scale_inv;
default:
NVTE_ERROR("Unknown tensor parameter!");
}
}
NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.scaling_mode;
}
void nvte_tensor_pack_create(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor);
......@@ -156,3 +397,18 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
delete t;
}
}
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
// Zero out tensor data if allocated
if (t.data.dptr != nullptr) {
size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor);
cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream);
}
// Set amax to 0 if allocated
if (t.amax.dptr != nullptr) {
float zero = 0.0f;
cudaMemcpyAsync(t.amax.dptr, &zero, sizeof(float), cudaMemcpyHostToDevice, stream);
}
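// Note: the synchronization below also guarantees that the host-side `zero`
// variable outlives the asynchronous copy above.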
cudaStreamSynchronize(stream);
}
......@@ -10,12 +10,12 @@
#include <algorithm>
#include "../common.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
#include "cast_transpose.h"
namespace transformer_engine {
namespace transformer_engine::detail {
namespace {
......@@ -217,54 +217,39 @@ __global__ void __launch_bounds__(block_size) cast_transpose_general_kernel(
} // namespace
void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_,
Tensor *transposed_output_, cudaStream_t stream) {
Tensor &cast_output = *cast_output_;
Tensor &transposed_output = *transposed_output_;
void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) {
Tensor &output = *output_;
// Check no-op flag
if (noop.data.dptr != nullptr) {
size_t numel = 1;
for (const auto &dim : noop.data.shape) {
numel *= dim;
}
NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr);
}
// Check tensor dims
CheckNoopTensor(noop, "cast_transpose_noop");
CheckInputTensor(input, "cast_transpose_input");
CheckOutputTensor(cast_output, "cast_output");
CheckOutputTensor(transposed_output, "transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions.");
NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output.");
NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output.");
NVTE_CHECK(transposed_output.data.shape[0] == row_length,
"Wrong dimension of transposed output.");
NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output.");
// Check tensor pointers
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(cast_output.data.dptr != nullptr, "Cast output is not allocated.");
NVTE_CHECK(transposed_output.data.dptr != nullptr, "Transposed output is not allocated.");
NVTE_CHECK(cast_output.data.dtype == transposed_output.data.dtype,
CheckOutputTensor(output, "cast_transpose_output");
// Check that inputs and outputs are available
NVTE_CHECK(input.has_data(), "Input is not allocated");
NVTE_CHECK(output.has_data(), "Output rowwise data is not allocated");
NVTE_CHECK(output.has_columnwise_data(), "Output columnwise is not allocated");
// Flatten tensor to 2D
NVTE_CHECK(input.data.shape == output.data.shape,
"Input and output shapes do not match (input=", input.data.shape,
", output=", output.data.shape);
const size_t row_length = input.flat_last_dim();
const size_t num_rows = input.flat_first_dim();
NVTE_CHECK(output.flat_first_dim() == num_rows && output.flat_last_dim() == row_length,
"Invalid output dimensions (expected ", std::vector<size_t>{num_rows, row_length},
", got ", std::vector<size_t>{output.flat_first_dim(), output.flat_last_dim()}, ")");
// Check that cast and transposed output data matches
NVTE_CHECK(output.data.dtype == output.columnwise_data.dtype,
"Cast and transposed output types must match.");
NVTE_CHECK(cast_output.amax.dptr == transposed_output.amax.dptr,
"Cast and transposed outputs need to share amax tensor.");
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor.");
NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr,
NVTE_CHECK(output.scale_inv.dptr == output.columnwise_scale_inv.dptr,
"Cast and transposed outputs need to share scale-inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
input.dtype(), InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
cast_output.data.dtype, OutputType,
output.dtype(), OutputType,
if (is_delayed_tensor_scaling(output.scaling_mode)) {
constexpr const char *itype_name = TypeInfo<InputType>::name;
constexpr const char *otype_name = TypeInfo<OutputType>::name;
constexpr size_t itype_size = sizeof(InputType);
......@@ -326,12 +311,11 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
static_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const CType *>(noop.data.dptr),
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length,
num_rows);
static_cast<OutputType *>(output.data.dptr),
static_cast<OutputType *>(output.columnwise_data.dptr),
static_cast<const CType *>(output.scale.dptr),
static_cast<CType *>(output.amax.dptr),
static_cast<CType *>(output.scale_inv.dptr), row_length, num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
......@@ -343,33 +327,33 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output
<<<num_blocks, block_size, 0, stream>>>(
static_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const CType *>(noop.data.dptr),
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr),
static_cast<CType *>(cast_output.scale_inv.dptr), row_length, num_rows);
static_cast<OutputType *>(output.data.dptr),
static_cast<OutputType *>(output.columnwise_data.dptr),
static_cast<const CType *>(output.scale.dptr),
static_cast<CType *>(output.amax.dptr),
static_cast<CType *>(output.scale_inv.dptr), row_length, num_rows);
}
} else {
NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode));
}); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
} // namespace transformer_engine::detail
void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
NVTETensor transposed_output, cudaStream_t stream) {
void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine;
auto noop = Tensor();
cast_transpose(*reinterpret_cast<const Tensor *>(input), noop,
reinterpret_cast<Tensor *>(cast_output),
reinterpret_cast<Tensor *>(transposed_output), stream);
transformer_engine::detail::cast_transpose(*reinterpret_cast<const Tensor *>(input), noop,
reinterpret_cast<Tensor *>(output), stream);
}
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop,
NVTETensor cast_output, NVTETensor transposed_output,
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_with_noop);
using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
reinterpret_cast<Tensor *>(cast_output),
reinterpret_cast<Tensor *>(transposed_output), stream);
transformer_engine::detail::cast_transpose(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(noop),
reinterpret_cast<Tensor *>(output), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#include "../common.h"
namespace transformer_engine::detail {
void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream);
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP &)>
void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output,
Tensor *dbias, Tensor *workspace, cudaStream_t stream);
template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType, const ParamOP &),
ComputeType (*OP2)(ComputeType, const ParamOP &)>
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output,
cudaStream_t stream);
} // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
......@@ -8,18 +8,19 @@
#include <transformer_engine/transpose.h>
#include <cfloat>
#include <iostream>
#include <functional>
#include <numeric>
#include <type_traits>
#include "../common.h"
#include "../util/math.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
#include "cast_transpose.h"
namespace transformer_engine {
namespace {
namespace detail {
// String with RTC kernel implementation
#include "string_code_transpose_rtc_cast_transpose_fusion_cu.h"
......@@ -177,16 +178,31 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
Tensor *workspace, const int nvec_out) {
const size_t row_length = cast_output.data.shape[1];
const size_t num_rows = cast_output.data.shape[0];
const size_t row_length = cast_output.flat_last_dim();
const size_t num_rows = cast_output.flat_first_dim();
const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {num_rows_partial_dbias, row_length};
workspace->data.dtype = DType::kFloat32;
} else {
// Check that workspace matches expected size
const size_t workspace_size =
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()) *
typeToSize(workspace->data.dtype);
const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
", dtype=", typeToSize(workspace->data.dtype), ")");
}
}
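// Sizing example (assuming nvec_out = 2 and THREADS_PER_WARP = 32, so
// tile_size_y = 64): a (4096, 1024) cast output needs a (64, 1024) float32
// workspace for the partial dbias reduction.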
template <int nvec, typename ComputeType, typename OutputType>
......@@ -248,11 +264,13 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt
reduce_dbias_num_rows);
}
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename Param, int nvec_in,
int nvec_out, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename Param,
int nvec_in, int nvec_out, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length,
const size_t num_rows, const size_t num_tiles) {
static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are mutually exclusive");
using IType = typename Param::InputType;
using IType2 = typename Param::InputType2;
using OType = typename Param::OutputType;
......@@ -373,6 +391,8 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
if constexpr (IS_DACT) {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
OP(act_in[current_in ^ 1][j].data.elt[k], {});
} else if constexpr (IS_ACT) {
after_dact[j].data.elt[k] = OP(in[current_in ^ 1][j].data.elt[k], {});
} else {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
}
......@@ -449,78 +469,96 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
}
static const char *ActTypeToString[] = {
"NoAct", // 0
"Sigmoid", // 1
"GeLU", // 2
"QGeLU", // 3
"SiLU", // 4
"ReLU", // 5
"SReLU" // 6
"none", // 0
"sigmoid", // 1
"dsigmoid", // 2
"gelu", // 3
"dgelu", // 4
"qgelu", // 5
"dqgelu", // 6
"silu", // 7
"dsilu", // 8
"relu", // 9
"drelu", // 10
"srelu", // 11
"dsrelu" // 12
};
template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
int get_dactivation_type() {
if (OP == &sigmoid<ComputeType, ComputeType>) {
return 1;
} else if (OP == &dgelu<ComputeType, ComputeType>) {
return 2;
} else if (OP == &dqgelu<ComputeType, ComputeType>) {
return 3;
} else if (OP == &dsilu<ComputeType, ComputeType>) {
return 4;
} else if (OP == &drelu<ComputeType, ComputeType>) {
return 5;
} else if (OP == &dsrelu<ComputeType, ComputeType>) {
return 6;
} else {
return 0;
constexpr int get_activation_type() {
constexpr decltype(OP) ActivationList[] = {
nullptr, // 0
&sigmoid<ComputeType, ComputeType>, // 1
&dsigmoid<ComputeType, ComputeType>, // 2
&gelu<ComputeType, ComputeType>, // 3
&dgelu<ComputeType, ComputeType>, // 4
&qgelu<ComputeType, ComputeType>, // 5
&dqgelu<ComputeType, ComputeType>, // 6
&silu<ComputeType, ComputeType>, // 7
&dsilu<ComputeType, ComputeType>, // 8
&relu<ComputeType, ComputeType>, // 9
&drelu<ComputeType, ComputeType>, // 10
&srelu<ComputeType, ComputeType>, // 11
&dsrelu<ComputeType, ComputeType> // 12
};
#pragma unroll
for (int i = 0; i < sizeof(ActivationList) / sizeof(ActivationList[0]); ++i) {
if (OP == ActivationList[i]) {
return i;
}
}
return 0;
}
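// For example, when OP is &dgelu<ComputeType, ComputeType> (entry 4 above), this
// returns 4, which selects the "dgelu" string in ActTypeToString.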
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP &)>
void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *cast_output,
Tensor *transposed_output, Tensor *dbias, Tensor *workspace,
cudaStream_t stream) {
if (workspace->data.dptr != nullptr) {
void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output,
Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
// Check tensors, unless querying dbias workspace
if (!IS_DBIAS || workspace->data.dptr != nullptr) {
CheckInputTensor(input, "cast_transpose_fused_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
if constexpr (IS_DBIAS) CheckOutputTensor(*dbias, "dbias");
if constexpr (IS_DACT) CheckInputTensor(act_input, "act_input");
CheckOutputTensor(*output, "output");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias != nullptr && dbias->has_data());
CheckOutputTensor(*dbias, "dbias");
}
if constexpr (IS_DACT) {
NVTE_CHECK(act_input != nullptr && act_input->has_data());
CheckInputTensor(*act_input, "act_input");
}
}
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
// Check that inputs and outputs are available
NVTE_CHECK(input.has_data(), "Input is not allocated");
NVTE_CHECK(output->has_data(), "Output rowwise data is not allocated");
NVTE_CHECK(output->has_columnwise_data(), "Output columnwise data is not allocated");
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
// Flatten tensor to 2D
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes do not match (input=", input.data.shape,
", output=", output->data.shape);
const size_t row_length = input.flat_last_dim();
const size_t num_rows = input.flat_first_dim();
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
// Check that cast and transposed output data matches
NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype,
"Cast and transposed output types must match.");
NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr,
"Cast and transposed outputs need to share scale-inverse tensor.");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{row_length}, "Wrong shape of DBias.");
}
if constexpr (IS_DACT) {
NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match.");
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
input.dtype(), InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
cast_output->data.dtype, OutputType, using InputType2 = InputType;
output->dtype(), OutputType, using InputType2 = InputType;
using Param = CTDBiasDActParam<InputType, InputType2, OutputType, ComputeType>;
constexpr int itype_size = sizeof(InputType);
......@@ -584,8 +622,9 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *
if (!jit_compiled) {
num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block);
} if constexpr (IS_DBIAS) {
// Check workspace size
populate_cast_transpose_dbias_workspace_config(*output, workspace, nvec_out);
if (workspace->data.dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
return;
}
}
......@@ -631,15 +670,15 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *
Param param;
param.input = reinterpret_cast<const InputType *>(input.data.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(transposed_output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(transposed_output->amax.dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
param.output_c = reinterpret_cast<OutputType *>(output->data.dptr);
param.output_t = reinterpret_cast<OutputType *>(output->columnwise_data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(output->amax.dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(output->scale_inv.dptr);
if constexpr (IS_DBIAS) {
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
} if constexpr (IS_DACT) {
param.act_input = reinterpret_cast<const InputType2 *>(act_input.data.dptr);
param.act_input = reinterpret_cast<const InputType2 *>(act_input->data.dptr);
}
// Runtime-compiled tuned kernel
......@@ -648,9 +687,9 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *
constexpr const char *itype2_name = TypeInfo<InputType2>::name;
constexpr const char *otype_name = TypeInfo<OutputType>::name;
int dActType = 0;
if constexpr (IS_DACT) {
dActType = get_dactivation_type<ComputeType, ParamOP, OP>();
int actType = 0;
if constexpr (IS_DACT || IS_ACT) {
actType = get_activation_type<ComputeType, ParamOP, OP>();
}
// Compile NVRTC kernel if needed and launch
......@@ -660,7 +699,8 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *
",itype=",
itype_name, ",itype2=", itype2_name, ",otype=", otype_name,
",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS,
",IS_DACT=", IS_DACT, ",dactivationType=", ActTypeToString[dActType]);
",IS_DACT=", IS_DACT, ",IS_ACT=", IS_ACT,
",activationType=", ActTypeToString[actType]);
if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu;
......@@ -673,7 +713,8 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *
code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads);
code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS);
code = regex_replace(code, "__IS_DACT__", IS_DACT);
code = regex_replace(code, "__DACTIVATION_TYPE__", dActType);
code = regex_replace(code, "__IS_ACT__", IS_ACT);
code = regex_replace(code, "__ACTIVATION_TYPE__", actType);
rtc_manager.compile(
kernel_label, "cast_transpose_fusion_kernel_optimized", code,
......@@ -695,11 +736,11 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
cudaFuncSetAttribute(
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, ComputeType, Param,
nvec_in, nvec_out, Empty, OP>,
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType,
Param, nvec_in, nvec_out, Empty, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, ComputeType, Param, nvec_in,
nvec_out, Empty, OP>
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Param,
nvec_in, nvec_out, Empty, OP>
<<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, num_tiles);
}
......@@ -1101,43 +1142,39 @@ __global__ void __launch_bounds__(cast_transpose_num_threads)
template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType, const ParamOP &),
ComputeType (*OP2)(ComputeType, const ParamOP &)>
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input,
Tensor *cast_output, Tensor *transposed_output,
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "dgated_act_cast_transpose_input");
CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input");
CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output");
CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output");
CheckOutputTensor(*output, "dgated_act_cast_transpose_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(output->has_data() && output->has_columnwise_data(),
"Both rowwise and columnwise data need to be allocated.");
NVTE_CHECK(output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(output->columnwise_data.shape.size() == 2, "T output must have 2 dimensions.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of gated_act_input.");
NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of gated_act_input.");
NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output.");
NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(output->data.shape[0] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(output->data.shape[1] == row_length * 2, "Wrong dimension of output.");
NVTE_CHECK(output->columnwise_data.shape[0] == row_length * 2, "Wrong dimension of T output.");
NVTE_CHECK(output->columnwise_data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
input.dtype(), InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
cast_output->data.dtype, OutputType, using InputType2 = InputType;
output->dtype(), OutputType, using InputType2 = InputType;
/* dact fusion kernel uses more registers */
constexpr int desired_load_size_dact = 4;
constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType);
......@@ -1168,11 +1205,11 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
<<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr), row_length, num_rows,
reinterpret_cast<OutputType *>(output->data.dptr),
reinterpret_cast<OutputType *>(output->columnwise_data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows,
n_tiles);
} else {
cudaFuncSetAttribute(
......@@ -1184,194 +1221,193 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
<<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr), row_length, num_rows,
reinterpret_cast<OutputType *>(output->data.dptr),
reinterpret_cast<OutputType *>(output->columnwise_data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows,
n_tiles);
}); // NOLINT(*)
); // NOLINT(*)
}
} // namespace
// Explicit template instantiation
template void cast_transpose_fused<true, false, false, float, transformer_engine::Empty, nullptr>(
const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t);
#define NVTE_INSTANTIATE_ACTIVATION(op) \
template void cast_transpose_fused<false, false, true, float, transformer_engine::Empty, \
transformer_engine::op<float, float>>( \
const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); \
template void cast_transpose_fused<false, true, false, float, transformer_engine::Empty, \
transformer_engine::d##op<float, float>>( \
const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t);
NVTE_INSTANTIATE_ACTIVATION(relu);
NVTE_INSTANTIATE_ACTIVATION(srelu);
NVTE_INSTANTIATE_ACTIVATION(gelu);
NVTE_INSTANTIATE_ACTIVATION(qgelu);
NVTE_INSTANTIATE_ACTIVATION(silu);
#undef NVTE_INSTANTIATE_ACTIVATION
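// For reference, a mechanical expansion of NVTE_INSTANTIATE_ACTIVATION(relu) above (shown only
// as a reading aid; the macro, not this comment, is authoritative):
//
//   template void cast_transpose_fused<false, false, true, float, transformer_engine::Empty,
//                                      transformer_engine::relu<float, float>>(
//       const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t);
//   template void cast_transpose_fused<false, true, false, float, transformer_engine::Empty,
//                                      transformer_engine::drelu<float, float>>(
//       const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t);
//
// i.e. the forward (IS_ACT) instantiation binds relu and the backward (IS_DACT) instantiation
// binds drelu, matching the IS_DBIAS/IS_DACT/IS_ACT flags used by the C entry points below.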
} // namespace detail
} // namespace transformer_engine
using ComputeType = typename transformer_engine::fp32;
void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr const NVTETensor activation_input = nullptr;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, nullptr>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(activation_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, nullptr>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(activation_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr auto dActivation = &dgelu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(act_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr auto dActivation = &dsilu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(silu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsilu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(silu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_drelu);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr auto dActivation = &drelu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(relu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, drelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(relu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr auto dActivation = &dsrelu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(srelu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsrelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(srelu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
constexpr auto dActivation = &dqgelu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(qgelu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dqgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(qgelu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr auto dActivation = &dgelu<fp32, fp32>;
constexpr auto Activation = &gelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
dgated_act_cast_transpose<ComputeType, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
reinterpret_cast<Tensor *>(output), stream);
}
void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu_cast_transpose);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr auto dActivation = &dsilu<fp32, fp32>;
constexpr auto Activation = &silu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
dgated_act_cast_transpose<ComputeType, Empty, dsilu<fp32, fp32>, silu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(swiglu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
reinterpret_cast<Tensor *>(output), stream);
}
void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu_cast_transpose);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr auto dActivation = &drelu<fp32, fp32>;
constexpr auto Activation = &relu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
dgated_act_cast_transpose<ComputeType, Empty, drelu<fp32, fp32>, relu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
reinterpret_cast<Tensor *>(output), stream);
}
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu_cast_transpose);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr auto dActivation = &dsrelu<fp32, fp32>;
constexpr auto Activation = &srelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
dgated_act_cast_transpose<ComputeType, Empty, dsrelu<fp32, fp32>, srelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
reinterpret_cast<Tensor *>(output), stream);
}
void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu_cast_transpose);
using namespace transformer_engine;
using namespace transformer_engine::detail;
constexpr auto dActivation = &dqgelu<fp32, fp32>;
constexpr auto Activation = &qgelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
dgated_act_cast_transpose<ComputeType, Empty, dqgelu<fp32, fp32>, qgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
reinterpret_cast<Tensor *>(output), stream);
}
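// Hedged caller-side sketch for the fused dgated entry points above (field names follow the
// internal Tensor usage in dgated_act_cast_transpose; the exact construction of Tensor objects
// in the framework may differ):
//
//   Tensor out;
//   out.data.dptr = cast_buf;                          // rowwise (C) result
//   out.data.shape = {num_rows, 2 * row_length};
//   out.columnwise_data.dptr = transpose_buf;          // columnwise (T) result
//   out.columnwise_data.shape = {2 * row_length, num_rows};
//   out.data.dtype = out.columnwise_data.dtype = DType::kFloat8E4M3;  // C and T must match
//   out.scale.dptr = scale_ptr;                        // per-tensor scale
//   out.amax.dptr = amax_ptr;                          // per-tensor amax
//   out.scale_inv.dptr = out.columnwise_scale_inv.dptr = scale_inv_ptr;  // must alias (checked)
//
// cast_buf, transpose_buf, scale_ptr, amax_ptr and scale_inv_ptr are hypothetical device
// allocations owned by the caller.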
......@@ -195,42 +195,44 @@ __global__ void __launch_bounds__(threads_per_block)
} // namespace
void multi_cast_transpose(const std::vector<Tensor*> input_list,
std::vector<Tensor*> cast_output_list,
std::vector<Tensor*> transposed_output_list, cudaStream_t stream) {
void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
cudaStream_t stream) {
// Check that number of tensors is valid
NVTE_CHECK(cast_output_list.size() == input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(transposed_output_list.size() == input_list.size(),
"Number of input and T output tensors must match");
NVTE_CHECK(output_list.size() == input_list.size(),
"Number of input and output tensors must match");
if (input_list.empty()) {
return;
}
// Check that tensor properties are valid
DType itype = input_list[0]->data.dtype;
DType otype = cast_output_list[0]->data.dtype;
DType otype = output_list[0]->dtype();
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = *input_list[tensor_id];
const auto& cast_output = *cast_output_list[tensor_id];
const auto& transposed_output = *transposed_output_list[tensor_id];
const auto& output = *output_list[tensor_id];
CheckInputTensor(input, "multi_cast_transpose_input_" + std::to_string(tensor_id));
CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id));
CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id));
CheckInputTensor(output, "multi_cast_transpose_output_" + std::to_string(tensor_id));
NVTE_CHECK(output.has_data() && output.has_columnwise_data(),
"Both rowwise and columnwise output data needs to be allocated.");
NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match.");
NVTE_CHECK(cast_output.data.dtype == otype, "C output tensor types do not match.");
NVTE_CHECK(transposed_output.data.dtype == otype, "T output tensor types do not match.");
NVTE_CHECK(output.data.dtype == otype, "C output tensor types do not match.");
NVTE_CHECK(output.data.dtype == otype, "T output tensor types do not match.");
NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
NVTE_CHECK(cast_output.data.shape == input.data.shape,
"C output tensor shape does not match input tensor.");
NVTE_CHECK(transposed_output.data.shape.size() == 2,
"T output tensor shape does not match input tensor.");
NVTE_CHECK(transposed_output.data.shape[0] == input.data.shape[1],
"T output tensor shape does not match input tensor.");
NVTE_CHECK(transposed_output.data.shape[1] == input.data.shape[0],
"T output tensor shape does not match input tensor.");
NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions, but shape is ",
input.data.shape);
NVTE_CHECK(output.data.shape == input.data.shape, "C output tensor shape ", output.data.shape,
"does not match input tensor shape ", input.data.shape);
NVTE_CHECK(output.columnwise_data.shape.size() == 2, "T output tensor shape ",
output.columnwise_data.shape, "does not match input tensor shape ",
input.data.shape);
NVTE_CHECK(output.columnwise_data.shape[0] == input.data.shape[1], "T output tensor shape ",
output.columnwise_data.shape, "does not match input tensor shape ",
input.data.shape);
NVTE_CHECK(output.columnwise_data.shape[1] == input.data.shape[0], "T output tensor shape ",
output.columnwise_data.shape, "does not match input tensor shape ",
input.data.shape);
}
// Input matrices are divided into tiles
......@@ -287,11 +289,11 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
// Add tensor to kernel argument struct
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->data.dptr;
kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr;
kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr;
kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr;
kernel_args.output_c_list[pos] = output_list[tensor_id]->data.dptr;
kernel_args.output_t_list[pos] = output_list[tensor_id]->columnwise_data.dptr;
kernel_args.scale_list[pos] = output_list[tensor_id]->scale.dptr;
kernel_args.amax_list[pos] = output_list[tensor_id]->amax.dptr;
kernel_args.scale_inv_list[pos] = output_list[tensor_id]->scale_inv.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
......@@ -327,15 +329,13 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
} // namespace transformer_engine
void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
NVTETensor* cast_output_list, NVTETensor* transposed_output_list,
cudaStream_t stream) {
NVTETensor* output_list, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_cast_transpose);
using namespace transformer_engine;
std::vector<Tensor*> input_list_, cast_output_list_, transposed_output_list_;
std::vector<Tensor*> input_list_, output_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
cast_output_list_.push_back(reinterpret_cast<Tensor*>(cast_output_list[i]));
transposed_output_list_.push_back(reinterpret_cast<Tensor*>(transposed_output_list[i]));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i]));
}
multi_cast_transpose(input_list_, cast_output_list_, transposed_output_list_, stream);
multi_cast_transpose(input_list_, output_list_, stream);
}
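// Hedged usage sketch for the C API above (the NVTETensor handles are assumed to be created and
// populated elsewhere, each output carrying both rowwise and columnwise buffers as required by
// the checks in multi_cast_transpose):
//
//   std::vector<NVTETensor> inputs(num_tensors);    // hypothetical, filled by the caller
//   std::vector<NVTETensor> outputs(num_tensors);   // hypothetical, filled by the caller
//   nvte_multi_cast_transpose(inputs.size(), inputs.data(), outputs.data(), stream);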
......@@ -22,7 +22,9 @@ constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__;
constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__;
constexpr bool IS_DBIAS = __IS_DBIAS__;
constexpr bool IS_DACT = __IS_DACT__;
constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__;
constexpr bool IS_ACT = __IS_ACT__;
static_assert(!(IS_DACT && IS_ACT), "forward and backward activations are mutually exclusive");
constexpr size_t ACT_TYPE = __ACTIVATION_TYPE__;
constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType);
constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType);
......@@ -33,14 +35,20 @@ using OVec = Vec<OType, NVEC_OUT>;
using Param = CTDBiasDActParam<IType, IType2, OType, CType>;
using OP = CType (*)(const CType, const Empty &);
constexpr OP Activation[] = {
constexpr OP ActivationList[] = {
nullptr, // 0
&dsigmoid<CType, CType>, // 1
&dgelu<CType, CType>, // 2
&dqgelu<CType, CType>, // 3
&dsilu<CType, CType>, // 4
&drelu<CType, CType>, // 5
&dsrelu<CType, CType> // 6
&sigmoid<CType, CType>, // 1
&dsigmoid<CType, CType>, // 2
&gelu<CType, CType>, // 3
&dgelu<CType, CType>, // 4
&qgelu<CType, CType>, // 5
&dqgelu<CType, CType>, // 6
&silu<CType, CType>, // 7
&dsilu<CType, CType>, // 8
&relu<CType, CType>, // 9
&drelu<CType, CType>, // 10
&srelu<CType, CType>, // 11
&dsrelu<CType, CType> // 12
};
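// Layout note on the table above: slot 0 is the no-op entry, forward activations occupy the odd
// indices (1, 3, 5, ...) and each derivative immediately follows its forward counterpart at the
// next even index. The host-side get_activation_type() is presumed to emit indices consistent
// with this ordering for both the IS_ACT and IS_DACT paths; this is an observation about the
// table, not a guarantee made by this file.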
} // namespace
......@@ -175,7 +183,10 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
if constexpr (IS_DACT) {
in_cast_fp32[j].data.elt[k] =
static_cast<CType>(in[current_in ^ 1][j].data.elt[k]) *
Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {});
ActivationList[ACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {});
} else if constexpr (IS_ACT) {
in_cast_fp32[j].data.elt[k] =
ActivationList[ACT_TYPE](in[current_in ^ 1][j].data.elt[k], {});
} else {
in_cast_fp32[j].data.elt[k] = static_cast<CType>(in[current_in ^ 1][j].data.elt[k]);
}
......
......@@ -205,17 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match.");
// Number of elements in tensor
auto numel = [](const Tensor &tensor) -> size_t {
size_t acc = 1;
for (const auto &dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1, "Expected 1 element, ", "but found ", numel(noop), ".");
NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr);
}
......
......@@ -8,8 +8,8 @@
#include <transformer_engine/transpose.h>
#include <cfloat>
#include <iostream>
#include <type_traits>
#include <functional>
#include <numeric>
#include "../common.h"
#include "../utils.cuh"
......@@ -376,8 +376,24 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);
if (workspace->data.dptr == nullptr) {
// Set workspace size
workspace->data.shape = {num_rows_partial_dbias, row_length};
workspace->data.dtype = DType::kFloat32;
} else {
// Check that workspace matches expected size
const size_t workspace_size =
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()) *
typeToSize(workspace->data.dtype);
const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
", dtype=", typeToSize(workspace->data.dtype), ")");
}
}
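// Hedged note on the two-pass workspace convention implemented above: a first call with
// workspace->data.dptr == nullptr only records the required shape {num_rows_partial_dbias,
// row_length} and dtype kFloat32; the caller then allocates that buffer and calls again, at
// which point the size/dtype validation in the else-branch runs. Roughly (names and the exact
// fp8_transpose_dbias argument order are hypothetical, sketched from the surrounding hunk):
//
//   Tensor workspace;                                        // data.dptr == nullptr: query pass
//   fp8_transpose_dbias(input, &t_out, &dbias, &workspace, stream);
//   workspace.data.dptr = my_alloc(workspace.data.shape, workspace.data.dtype);
//   fp8_transpose_dbias(input, &t_out, &dbias, &workspace, stream);  // compute pass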
template <typename BiasType>
......@@ -426,10 +442,9 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor
constexpr int nvec_in = desired_load_size / type_size;
constexpr int nvec_out = desired_store_size / type_size;
if (workspace->data.dptr == nullptr) {
// Check workspace size
populate_transpose_dbias_workspace_config(input, workspace, nvec_out);
return;
}
if (workspace->data.dptr == nullptr) { return; }
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
......
......@@ -4,88 +4,144 @@
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <cfloat>
#include <limits>
#include <string>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize);
using namespace transformer_engine;
namespace transformer_engine {
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
namespace detail {
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, nullptr, output,
dbias, workspace, stream);
}
struct Empty {};
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_noop);
using namespace transformer_engine;
__device__ inline fp32 identity(fp32 value, const Empty &) { return value; }
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
struct DequantizeParam {
const fp32 *scale_inv;
};
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, noop, output,
dbias, workspace, stream);
}
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias);
using namespace transformer_engine;
__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param) {
return value * (*(param.scale_inv));
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr const NVTETensor activation_input = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, nullptr, output, dbias, workspace, stream);
}
} // namespace detail
void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision.");
NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {},
stream);); // NOLINT(*)
); // NOLINT(*)
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
}
void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
}
} // namespace transformer_engine
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
}
void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_quantize);
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
stream);
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
}
void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_dequantize);
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
stream);
detail::dequantize_helper(*reinterpret_cast<const Tensor *>(input),
reinterpret_cast<Tensor *>(output), stream);
}
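// Hedged usage sketch of the renamed entry points above (handles are assumed to be pre-built
// NVTETensor objects with matching shapes; the quantize step is expected to also update the
// output's amax/scale_inv as the previous fp8_quantize path did):
//
//   nvte_quantize(hp_input, fp8_tensor, stream);     // high precision -> FP8
//   nvte_dequantize(fp8_tensor, hp_output, stream);  // FP8 -> high precision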
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file cast_gated_kernels.cuh
* \brief CUDA gated activations kernels to cast to/from FP8/MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_
#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <cfloat>
#include "../common.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "ptx.cuh"
namespace transformer_engine {
template <typename T1, typename T2>
__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) {
return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}
namespace gated_kernels {
constexpr size_t ALIGNMENT_SIZE = 128;
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512;
constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X;
constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128
constexpr size_t BUFFERS_NUM = 2;
constexpr size_t BUFFER_DIM_Y = 32;
constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128
constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32
constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128
constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32
static_assert(ITERATIONS >= 1);
__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }
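// Note: sigmoidf above is the logistic function sigma(x) = 1 / (1 + exp(-x)), evaluated with the
// round-to-nearest reciprocal intrinsic; both gated kernels below lean on it for their SiLU
// fast path.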
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
const __grid_constant__ CUtensorMap tensor_map_input_gate,
const __grid_constant__ CUtensorMap tensor_map_output_act,
const __grid_constant__ CUtensorMap tensor_map_output_gate,
float *const amax_ptr, float *const scale_inv_ptr,
const float *const scale_ptr, const size_t rows, const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const int thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
extern __shared__ char dshmem_unaligned[];
const uint64_t dshmem_unaligned_as_uint = reinterpret_cast<uint64_t>(dshmem_unaligned);
const uint64_t dshmem_aligned_as_uint =
DIVUP(dshmem_unaligned_as_uint, static_cast<uint64_t>(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE;
char *dshmem = reinterpret_cast<char *>(dshmem_aligned_as_uint);
constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE;
constexpr size_t buff_size_aligned_out =
DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE;
constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0;
constexpr size_t in_act_mem = buff_size_aligned_in;
constexpr size_t in_gate_mem = buff_size_aligned_in;
constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out;
constexpr size_t out_gate_mem = buff_size_aligned_out;
constexpr size_t out_mem = out_act_mem + out_gate_mem;
// const size_t in_transaction_size = grad_mem + in_mem;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);
// uint64_t *mbar = reinterpret_cast<uint64_t *>(dshmem + grad_mem + in_mem + out_mem);
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
const uint64_t *TMAP_in_gate = reinterpret_cast<const uint64_t *>(&tensor_map_input_gate);
const uint64_t *TMAP_output_act = reinterpret_cast<const uint64_t *>(&tensor_map_output_act);
const uint64_t *TMAP_output_gate = reinterpret_cast<const uint64_t *>(&tensor_map_output_gate);
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
initialize_barriers<ITERATIONS, THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
// Prefetch data of the first stage
if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh,
TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate,
chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
} else {
copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh,
TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0],
is_master_thread);
}
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
const int buff = it % BUFFERS_NUM;
const int next_it = it + 1;
if (next_it < ITERATIONS) {
const int next_buff = next_it % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y,
&in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y,
in_transaction_size, &mbar[next_it], is_master_thread);
} else {
copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x,
chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate,
chunk_it_offset_x, chunk_it_offset_y, in_transaction_size,
&mbar[next_it], is_master_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[it], parity);
IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems;
IType *in_act_sh_curr = in_act_sh + buff * buff_elems;
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
}
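// Sanity check for the SiLU fast path above: with s = sigmoid(x),
//   act(x)  = x * s
//   act'(x) = s + x * s * (1 - s)
// which is exactly the fused `x * s * (1 - s) + s` expression, so the specialization matches the
// generic ActOP/DActOP pair while evaluating the sigmoid only once.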
float after_dact = dact_x * grad_elt * gate_elt;
float after_dgate = act_x * grad_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
amax = fmaxf(amax, fabsf(after_dact));
amax = fmaxf(amax, fabsf(after_dgate));
} else {
const float after_act = ActOP(act_elt, {}) * gate_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
amax = fmaxf(amax, fabsf(after_act));
}
}
// Wait for shared memory writes to be visible to TMA engine (cross-proxy fence)
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
// dGeLU
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_act_sh_curr));
if constexpr (IS_DGATED) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_gate_sh_curr));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<BUFFERS_NUM - 1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
// Destroy the barriers. This invalidates the memory region of the barrier.
// If further computations were to take place in the kernel, this allows the
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_invalid(&mbar[it]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
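// Rough flow of cast_fp8_gated_kernel above, as a reading aid only: (1) TMA-prefetch the first
// 32x128 buffer set into shared memory; (2) per iteration, kick off the prefetch of the next
// buffer, wait on the mbarrier, apply ActOP (plus DActOP and the incoming grad in the dgated
// case), and scale by the per-tensor FP8 scale while tracking the local amax; (3) TMA-store the
// shared-memory results back to global memory; (4) after the loop, block-reduce the amax,
// atomically update the global amax, and let the master thread of blocks with blockIdx.x == 0
// write the reciprocal scale to scale_inv_ptr.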
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
const __grid_constant__ CUtensorMap tensor_map_input_gate,
const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_act_colwise,
const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise,
e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32
constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32
constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128
const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y;
const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X;
const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y;
const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X;
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const int thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X;
const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols);
extern __shared__ char dshmem_unaligned[];
const uint64_t dshmem_unaligned_as_uint = reinterpret_cast<uint64_t>(dshmem_unaligned);
const uint64_t dshmem_aligned_as_uint =
DIVUP(dshmem_unaligned_as_uint, static_cast<uint64_t>(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE;
char *dshmem = reinterpret_cast<char *>(dshmem_aligned_as_uint);
const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_elems_total = BUFFERS_NUM * buff_elems;
const size_t buff_size_aligned_in =
DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE;
const size_t buff_size_aligned_out =
DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE;
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t in_mem = in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t out_mem = out_act_mem + out_gate_mem;
// const size_t in_transaction_size = grad_mem + in_mem;
const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
OType *out_act_rowwise_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
OType *out_gate_rowwise_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);
OType *out_act_colwise_sh = out_act_rowwise_sh;
OType *out_gate_colwise_sh = out_gate_rowwise_sh;
if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) {
out_act_colwise_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem);
out_gate_colwise_sh =
reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem + out_act_mem);
}
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
const uint64_t *TMAP_in_gate = reinterpret_cast<const uint64_t *>(&tensor_map_input_gate);
const uint64_t *TMAP_output_act_rowwise =
reinterpret_cast<const uint64_t *>(&tensor_map_output_act_rowwise);
const uint64_t *TMAP_output_gate_rowwise =
reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_rowwise);
const uint64_t *TMAP_output_act_colwise =
reinterpret_cast<const uint64_t *>(&tensor_map_output_act_colwise);
const uint64_t *TMAP_output_gate_colwise =
reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_colwise);
__shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X];
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
const bool is_master_thread = (threadIdx.x == 0);
if (is_master_thread) {
// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate.
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK);
}
ptx::fence_proxy_async_shared_cta();
}
// Syncthreads so initialized barrier is visible to all threads.
__syncthreads();
int parity = 0;
// Prefetch data of the first stage
if (is_master_thread) {
// Initiate bulk tensor copy
// Grad
if constexpr (IS_DGATED) {
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(&in_grad_sh[0]),
TMAP_grad_in, chunk_offset_X, chunk_offset_Y,
&mbar[0]);
}
// Act
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(&in_act_sh[0]),
TMAP_in_act, chunk_offset_X, chunk_offset_Y,
&mbar[0]);
// Gate
ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(&in_gate_sh[0]),
TMAP_in_gate, chunk_offset_X, chunk_offset_Y,
&mbar[0]);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size);
} else {
// Other threads just arrive
ptx::mbarrier_arrive(&mbar[0]);
}
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
const int buff = it % BUFFERS_NUM;
const int next_it = it + 1;
const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y;
if (next_it < ITERATIONS) {
if (is_master_thread) {
const int next_buff = next_it % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
// Initiate bulk tensor copy
if constexpr (IS_DGATED) {
// Grad
ptx::cp_async_bulk_tensor_2d_global_to_shared(
reinterpret_cast<uint64_t *>(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in,
chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]);
}
// Act
ptx::cp_async_bulk_tensor_2d_global_to_shared(
reinterpret_cast<uint64_t *>(&in_act_sh[next_buff * buff_elems]), TMAP_in_act,
chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]);
// Gate
ptx::cp_async_bulk_tensor_2d_global_to_shared(
reinterpret_cast<uint64_t *>(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate,
chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]);
// Arrive on the barrier and tell how many bytes are expected to come in.
ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size);
} else {
// Other threads just arrive
ptx::mbarrier_arrive(&mbar[next_it]);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[it], parity);
IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems;
IType *in_act_sh_curr = in_act_sh + buff * buff_elems;
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems;
OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems;
OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems;
OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems;
// Assuming one iteration covers exactly 32 rows
const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it;
const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y;
float after_dact_reg[BUFFER_STAGES_NUM];
float after_dgate_reg[BUFFER_STAGES_NUM];
float thread_Y_mx_block_amax = 0.0f;
float thread_Y_mx_block_amax_gate = 0.0f;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
const size_t row = row_base + shmem_offset_y;
const bool row_out_of_bounds = (row >= rows);
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
}
after_dact_reg[stage] = dact_x * grad_elt * gate_elt;
after_dgate_reg[stage] = act_x * grad_elt;
} else {
after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt;
}
if constexpr (USE_ROWWISE_SCALING) {
if constexpr (IS_DGATED) {
// dgate
float amax = fabsf(after_dgate_reg[stage]);
const float mx_block_X_amax = warp_reduce_max_broadcast(amax);
const e8m0_t biased_exponent_X =
float_to_e8m0(mx_block_X_amax * Quantized_Limits<OType>::max_norm_rcp);
const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X);
out_gate_rowwise_sh_curr[shmem_idx] =
static_cast<OType>(scale_reciprocal_X * after_dgate_reg[stage]);
// Only single thread writes the computed scaling factor
if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) {
const int global_scales_offset_Y =
iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y;
const int global_scales_offset_X =
scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X;
const int scale_idx =
global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent_X;
}
}
float amax = fabsf(after_dact_reg[stage]);
const float mx_block_X_amax = warp_reduce_max_broadcast(amax);
const e8m0_t biased_exponent_X =
float_to_e8m0(mx_block_X_amax * Quantized_Limits<OType>::max_norm_rcp);
const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X);
out_act_rowwise_sh_curr[shmem_idx] =
static_cast<OType>(scale_reciprocal_X * after_dact_reg[stage]);
// Only single thread writes the computed scaling factor
if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) {
const int global_scales_offset_Y =
iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y;
const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X;
const int scale_idx =
global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent_X;
}
}
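// Hedged summary of the rowwise MX scaling above: each 1x32 block reduces its amax across the
// warp, converts amax * Quantized_Limits<OType>::max_norm_rcp (amax divided by the largest
// representable OType value) into an E8M0 biased exponent, and quantizes with the matching
// power-of-two reciprocal from exp2f_rcp. Only the lane with tid_X % SCALE_DIM_X == 0 stores the
// one-byte scale, and the dgate scales are shifted by cols / SCALE_DIM_X columns in the scale
// matrix (assuming cols is a multiple of SCALE_DIM_X). The helpers float_to_e8m0 / exp2f_rcp are
// assumed to implement exactly this rounding convention.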
if constexpr (USE_COLWISE_SCALING) {
__builtin_assume(thread_Y_mx_block_amax >= 0);
__builtin_assume(thread_Y_mx_block_amax_gate >= 0);
thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage]));
if constexpr (IS_DGATED) {
thread_Y_mx_block_amax_gate =
fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage]));
}
}
}
if constexpr (USE_COLWISE_SCALING) {
const bool row_out_of_bounds = (row_base >= rows);
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
if constexpr (IS_DGATED) {
// Colwise max reduction of the amax element
if (tid_Y > 0) {
stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate;
}
__syncthreads();
if (tid_Y == 0) {
#pragma unroll
for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) {
thread_Y_mx_block_amax_gate =
fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]);
}
stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax
}
__syncthreads();
const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax
// For the scaling along both dimensions, the thread amax is already computed in ROWWISE section
if constexpr (!USE_ROWWISE_SCALING) {
__builtin_assume(mx_block_Y_amax >= 0);
}
const e8m0_t biased_exponent =
float_to_e8m0(mx_block_Y_amax * Quantized_Limits<OType>::max_norm_rcp);
const float scale_reciprocal = exp2f_rcp(biased_exponent);
// Only a single thread writes the computed scaling factor
// Also assuming one iteration covers exactly 32 rows
if ((tid_Y == 0) && !out_of_bounds) {
const int global_scales_offset_Y = iteration_scale_colwise_offset_Y;
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols;
const int scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
}
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
out_gate_colwise_sh_curr[shmem_idx] =
static_cast<OType>(scale_reciprocal * after_dgate_reg[stage]);
}
}
// Colwise max reduction of the amax element
if (tid_Y > 0) {
stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax;
}
__syncthreads();
if (tid_Y == 0) {
#pragma unroll
for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) {
thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]);
}
stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax
}
__syncthreads();
const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax
// For the scaling along both dimensions, the thread amax is already computed in ROWWISE section
if constexpr (!USE_ROWWISE_SCALING) {
__builtin_assume(mx_block_Y_amax >= 0);
}
const e8m0_t biased_exponent =
float_to_e8m0(mx_block_Y_amax * Quantized_Limits<OType>::max_norm_rcp);
const float scale_reciprocal = exp2f_rcp(biased_exponent);
// Only a single thread writes the computed scaling factor
// Also assuming one iteration covers exactly 32 rows
if ((tid_Y == 0) && !out_of_bounds) {
const int global_scales_offset_Y = iteration_scale_colwise_offset_Y;
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X;
const int scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
}
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
out_act_colwise_sh_curr[shmem_idx] =
static_cast<OType>(scale_reciprocal * after_dact_reg[stage]);
}
} // endif USE_COLWISE_SCALING
// Wait for shared memory writes to be visible to TMA engine (cross-proxy fence)
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
// Activation output
if constexpr (USE_ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_act_rowwise_sh_curr));
if constexpr (IS_DGATED) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_gate_rowwise_sh_curr));
}
}
// Activation output
if constexpr (USE_COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_act_colwise_sh_curr));
if constexpr (IS_DGATED) {
// dGate
ptx::cp_async_bulk_tensor_2d_shared_to_global(
TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y,
reinterpret_cast<uint64_t *>(out_gate_colwise_sh_curr));
}
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<BUFFERS_NUM - 1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
// Destroy the barriers. This invalidates the memory region of the barrier.
// If further computations were to take place in the kernel, this allows the
// memory location of the shared memory barrier to be reused.
if (is_master_thread) {
#pragma unroll
for (int it = 0; it < ITERATIONS; ++it) {
ptx::mbarrier_invalid(&mbar[it]);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) {
if (output->has_data()) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
}
NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function.");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block_dim(THREADS_PER_CHUNK);
const dim3 grid_dim(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act{};
alignas(64) CUtensorMap tensor_map_output_gate{};
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, sizeof(IType));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, sizeof(IType));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, sizeof(IType));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE;
const size_t buff_size_aligned_out =
DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE;
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
// const size_t mbar_mem = ITERATIONS * sizeof(uint64_t);
const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem); // + mbar_mem;
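// Dynamic shared memory layout: alignment padding, then (for dgated only) the grad buffer, the
// activation and gate input buffers, and the activation and gate output buffers, each rounded up
// to ALIGNMENT_SIZE above.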
cudaFuncSetAttribute(
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>
<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);); // NOLINT(*)
); // NOLINT(*)
}
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) {
const bool USE_ROWWISE_SCALING = output->has_data();
const bool USE_COLWISE_SCALING = output->has_columnwise_data();
if (USE_ROWWISE_SCALING) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
}
if (USE_COLWISE_SCALING) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
}
// TODO: Make more general
const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1;
const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1;
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1;
size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
e8m0_t *const scales_rowwise_ptr =
USE_ROWWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->scale_inv.dptr) : nullptr;
e8m0_t *const scales_colwise_ptr =
USE_COLWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const dim3 block_dim(THREADS_PER_CHUNK);
const dim3 grid_dim(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_Y_colwise, SCALE_DIM_Y,
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_X_rowwise, SCALE_DIM_X,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act_rowwise{};
alignas(64) CUtensorMap tensor_map_output_gate_rowwise{};
alignas(64) CUtensorMap tensor_map_output_act_colwise{};
alignas(64) CUtensorMap tensor_map_output_gate_colwise{};
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(IType));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType));
if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0,
sizeof(OType));
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols,
sizeof(OType));
}
if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
0, sizeof(OType));
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
cols, sizeof(OType));
}
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE;
const size_t buff_size_aligned_out =
DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE;
const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
size_t out_mem = out_act_mem + out_gate_mem;
if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }
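// Emitting both rowwise- and columnwise-scaled outputs requires a second pair of output buffers.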
// const size_t mbar_mem = ITERATIONS * sizeof(uint64_t);
// const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem;
const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem;
cudaFuncSetAttribute(
cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
SCALE_DIM_Y, SCALE_DIM_X>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
SCALE_DIM_Y, SCALE_DIM_X>
<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
"Input shape[0] must be equal to output shape[0].");
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), output->data.shape[0],
output->data.shape[1], {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output");
NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(),
"Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(),
", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2,
"Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2,
"], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes must match. Input shape: ", input.data.shape,
", output shape: ", output->data.shape, ".");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->dtype(), OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, ParamOP, ActOP, DActOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
grad.flat_last_dim(), {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
cudaStream_t stream) {
checkCuDriverContext(stream);
constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input");
CheckOutputTensor(*output, "output", allow_empty);
NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even.");
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;
if constexpr (IS_DGATED) {
CheckInputTensor(grad, "grad");
NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input.");
NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input.");
}
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
"Either rowwise or columnwise output data need to be allocated.");
bool is_fp8_rowwise_output = true;
bool is_fp8_colwise_output = true;
if (output->has_data()) {
is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype);
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output.");
NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output.");
}
if (output->has_columnwise_data()) {
is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype);
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output.");
NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output.");
}
const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0;
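// Dispatch: the TMA-based kernels require FP8 output data (rowwise and, if allocated, columnwise)
// and a column count divisible by 32; otherwise fall back to the generic vectorized kernels,
// which only support delayed tensor scaling.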
if (is_delayed_tensor_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
} else {
if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
} else {
cast_gated<ParamOP, ActOP>(gated_input, output, stream);
}
}
} else if (is_mxfp_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
} else {
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
"by 32, got input of shape ", gated_input.data.shape);
}
} else {
NVTE_ERROR("Not supported scaling mode");
}
}
} // namespace gated_kernels
namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor =
IS_DGATED ? *(reinterpret_cast<const Tensor *>(grad)) : grad_empty_tensor;
const Tensor gated_input_tensor = *reinterpret_cast<const Tensor *>(gated_input);
Tensor *output_tensor = reinterpret_cast<Tensor *>(output);
if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
output_tensor, stream);
} else {
if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) {
if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, stream);
} else {
cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, stream);
}
} else {
// MX scaling
NVTE_ERROR("Not supported by the Arch < 10.0");
}
}
}
} // namespace detail
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file cast_kernels.cuh
* \brief CUDA kernels to cast to/from FP8/MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_
#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <cfloat>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
constexpr size_t MXFP8_CHUNK_DIM_Y = 64;
constexpr size_t MXFP8_CHUNK_DIM_X = 64;
constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1;
constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1;
constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X;
constexpr size_t MXFP8_THREADS_PER_CHUNK = 64;
constexpr size_t MXFP8_BUFFERS_NUM = 2;
constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1;
static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM);
constexpr size_t ELEMS_PER_THREAD = 16;
constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported
constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64
constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32
constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64
constexpr size_t THREADS_PER_CHUNK_X_ROWWISE =
MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16
constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE =
MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4
constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64
constexpr size_t MXFP8_BUFF_STAGES_NUM =
MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16
constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32
static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM);
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, size_t SCALE_DIM_Y,
size_t SCALE_DIM_X>
__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_act_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_colwise,
e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
const float *noop, float *const dbias_workspace, float *const amax_ptr,
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
if (noop != nullptr && noop[0] == 1.0f) return;
}
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y;                // 64
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X;  // 2 = 64 / 32
constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y =
SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y;  // 64 = 64 * 1
constexpr size_t SCALES_ROWWISE_PER_BLOCK_X =
SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X;  // 2 = 2 * 1
constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32
constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1
constexpr size_t SCALES_COLWISE_PER_BLOCK_Y =
SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1
constexpr size_t SCALES_COLWISE_PER_BLOCK_X =
SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2
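// With ELEMS_PER_THREAD = 16 and SCALE_DIM_X = 32, two adjacent threads cover one MX block, so
// the rowwise block amax is reduced over a subwarp of width 2.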
const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X;
const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y;
const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X;
const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y;
const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X;
const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE;
const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE;
// const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE;
const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE;
const int thread_offset_Y = tid_rowwise_Y;
const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD;
// const int thread_offset_X_colwise = tid_colwise_X;
const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y;
const int dbias_rowwise_block_offset_X =
blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise;
const int dbias_colwise_offset_Y = blockIdx.y;
const int dbias_colwise_block_offset_X =
blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X;
const int dbias_stride = cols;
Vec<float, ELEMS_PER_THREAD> partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X];
float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X];
if constexpr (IS_DBIAS) {
if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) {
#pragma unroll
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) {
partial_dbias_rowwise[i].clear();
}
} else {
#pragma unroll
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) {
partial_dbias_colwise[i] = 0;
}
}
}
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
__shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
__shared__ alignas(128)
OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
__shared__ alignas(128)
OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1);
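// A single TMA transaction fills the input buffer and, when IS_DACT, the activation-input buffer
// as well, hence the doubled transaction size.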
const bool is_master_thread = (threadIdx.x == 0);
float block_amax = 0;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS];
initialize_barriers<MXFP8_ITERATIONS, MXFP8_THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
#pragma unroll
for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) {
const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X;
const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X;
const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y;
const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
const int scales_rowwise_chunk_offset_Y =
scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y;
const int scales_rowwise_chunk_offset_X =
scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X;
const int scales_colwise_chunk_offset_Y =
scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y;
const int scales_colwise_chunk_offset_X =
scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X;
#pragma unroll
for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) {
const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y;
const int chunk_stage_offset_X = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size,
&mbar[prefetch_buff], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff],
is_master_thread);
}
}
#pragma unroll
for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) {
const int buff = iter % MXFP8_BUFFERS_NUM;
const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM;
const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
if (next_iter < MXFP8_ITERATIONS) {
const int next_buff = next_iter % MXFP8_BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size,
&mbar[next_iter], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
if constexpr (USE_ROWWISE_SCALING) {
Vec<IType, ELEMS_PER_THREAD> in;
Vec<IType, ELEMS_PER_THREAD> act_in;
Vec<OType, ELEMS_PER_THREAD> out_c;
const int iteration_scale_rowwise_offset_Y =
scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
#pragma unroll
for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X_rowwise;
const size_t row = row_base + shmem_offset_y;
const bool row_out_of_bounds = (row >= rows);
in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]);
if constexpr (IS_DACT) {
act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
}
float thread_amax = 0;
float in_compute[ELEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols);
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
float elt = static_cast<float>(in.data.elt[j]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in.data.elt[j]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) {
if (!out_of_bounds) {
partial_dbias_rowwise[chunk_X].data.elt[j] += elt;
}
}
in_compute[j] = elt;
if (!out_of_bounds) {
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
}
__builtin_assume(block_amax >= 0);
__builtin_assume(thread_amax >= 0);
block_amax = fmaxf(block_amax, thread_amax);
const float subwarp_amax = subwarp_reduce_max_broadcast<SUBWARP_WIDTH>(thread_amax);
const e8m0_t biased_exponent =
float_to_e8m0(subwarp_amax * Quantized_Limits<OType>::max_norm_rcp);
// Only a single thread writes the computed scaling factor
if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) {
const int global_scales_offset_Y =
iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y;
const int global_scales_offset_X =
scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE;
const int scale_idx =
global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent;
}
const float block_scale_inverse = exp2f_rcp(biased_exponent);
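// The stored E8M0 value decodes to a scale of 2^(biased_exponent - 127) (the usual E8M0
// convention); multiplying by its reciprocal here keeps the quantized block within the FP8 range.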
#pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
out_c.data.elt[j] = static_cast<OType>(in_compute[j] * block_scale_inverse);
}
out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]);
}
}
if constexpr (USE_COLWISE_SCALING) {
const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols);
float in_compute[SCALE_DIM_Y];
float amax = 0;
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const size_t row = row_base + i;
const bool row_out_of_bounds = (row >= rows);
const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
float elt = static_cast<float>(in_sh[buff][i][tid_colwise_X]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[buff][i][tid_colwise_X]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
if (!out_of_bounds) {
partial_dbias_colwise[chunk_X] += elt;
}
}
in_compute[i] = elt;
if (!out_of_bounds) {
amax = fmaxf(amax, fabsf(elt));
}
}
__builtin_assume(block_amax >= 0);
__builtin_assume(amax >= 0);
block_amax = fmaxf(block_amax, amax);
const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter;
const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X;
const int scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = exp2f_rcp(biased_exponent);
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
out_colwise_sh[buff][i][tid_colwise_X] =
static_cast<OType>(in_compute[i] * block_scale_inverse);
}
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
if constexpr (USE_ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff]));
}
if constexpr (USE_COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<MXFP8_PREFETCH_BUFFERS_NUM>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
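// The same mbarriers are reused by the next chunk; flipping the parity selects the phase that
// mbarrier_wait_parity polls on the next pass.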
parity ^= 1;
}
if constexpr (IS_DBIAS) {
if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) {
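// Two-stage dbias reduction: threads with tid_rowwise_Y > 0 stash their partial sums in shared
// memory, thread row 0 accumulates them and writes one partial row per block to the workspace;
// the final reduction across blocks happens later in reduce_dbias_kernel.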
constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X;
constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1;
constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE;
__shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD];
if (tid_rowwise_Y > 0) {
#pragma unroll
for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) {
partial_dbias_rowwise[c].store_to(
&shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]);
}
}
__syncthreads();
if (tid_rowwise_Y == 0) {
#pragma unroll
for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) {
Vec<float, ELEMS_PER_THREAD> other_row_dbias;
const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X;
const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X;
const int left_bound = dbias_rowwise_offset_X;
const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1;
#pragma unroll
for (int i = 0; i < Y; ++i) {
other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]);
#pragma unroll
for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j];
}
}
// Vectorized store when all elements are inside the boundaries
if (right_bound < cols) {
partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]);
} else if (left_bound < cols && right_bound >= cols) {
// Element-by-element store when some elements cross the boundaries
const int in_bound_elts_count = cols - left_bound;
partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0,
in_bound_elts_count);
}
}
}
} else {
#pragma unroll
for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) {
const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X;
const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X;
const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols);
if (!col_out_of_bounds) {
dbias_workspace[dbias_offset] = partial_dbias_colwise[i];
}
}
}
}
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
block_amax = reduce_max<MXFP8_THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
}
if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, block_amax);
}
destroy_barriers<MXFP8_ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
constexpr size_t FP8_BUFFERS_NUM = 2;
constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1;
static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM);
constexpr size_t FP8_BUFFER_DIM_Y = 16;
constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128
constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16
constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128
constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16
constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16
static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM);
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType>
__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_act_input,
const __grid_constant__ CUtensorMap tensor_map_output,
float *const dbias_workspace, float *const amax_ptr,
float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows,
const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X;
const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK;
const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK;
const int thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X;
const int dbias_offset_Y = blockIdx.y + tid_Y;
const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X;
const bool col_out_of_bounds = my_column >= cols;
const int dbias_stride = cols;
float partial_dbias = 0.f;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
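// Delayed tensor scaling: one precomputed scale is applied to the whole tensor, while the amax
// observed in this pass is reduced and written out for the next scale update.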
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
__shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1);
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS];
initialize_barriers<FP8_ITERATIONS, FP8_THREADS_PER_CHUNK>(mbar, is_master_thread);
int parity = 0;
const int chunk_offset_Y = block_offset_Y;
const int chunk_offset_X = block_offset_X;
#pragma unroll
for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) {
const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y;
const int chunk_stage_offset_X = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size,
&mbar[prefetch_buff], is_master_thread);
} else {
copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff],
is_master_thread);
}
}
#pragma unroll
for (int iter = 0; iter < FP8_ITERATIONS; ++iter) {
const int buff = iter % FP8_BUFFERS_NUM;
const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM;
const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y;
if (next_iter < FP8_ITERATIONS) {
const int next_buff = next_iter % FP8_BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter],
is_master_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread);
}
}
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
#pragma unroll
for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X;
const size_t row = row_base + shmem_offset_y;
const bool row_out_of_bounds = row >= rows;
const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds;
float elt = static_cast<float>(in_sh[buff][shmem_offset_y][shmem_offset_x]);
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
if constexpr (IS_DACT) {
if (!out_of_bounds) {
partial_dbias += elt;
}
} else {
// Without dact, out-of-bounds elements are zero (TMA zero-fills the shared-memory tile),
// so accumulating them unconditionally is safe
partial_dbias += elt;
}
}
__builtin_assume(amax >= 0);
if constexpr (IS_DACT) {
if (!out_of_bounds) {
amax = fmaxf(amax, fabsf(elt));
}
} else {
// Without dact, out-of-bounds elements are zero (TMA zero-fills the shared-memory tile),
// so including them in the amax is safe
amax = fmaxf(amax, fabsf(elt));
}
out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast<OType>(elt * scale);
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff]));
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<FP8_PREFETCH_BUFFERS_NUM>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
parity ^= 1;
if constexpr (IS_DBIAS) {
const int dbias_offset_X = my_column;
const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X;
if (!col_out_of_bounds) {
dbias_workspace[dbias_offset] = partial_dbias;
}
}
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<FP8_THREADS_PER_CHUNK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
destroy_barriers<FP8_ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
constexpr size_t CHUNKS_PER_BLOCK = 128;
constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK;
constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK;
constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE;
constexpr size_t CHUNKS_PER_ITERATION = 32;
constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE;
constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION;
constexpr size_t SHMEM_BUFFERS = 2;
static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0);
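// Each block processes CHUNKS_PER_BLOCK * CHUNK_SIZE = 128 * 128 = 16384 contiguous elements,
// streamed through two shared-memory buffers, 32 chunks per iteration.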
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &), typename IType,
typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr,
float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int block_offset = blockIdx.x * ELEMS_PER_BLOCK;
const IType *input = input_ptr + block_offset;
OType *output = output_ptr + block_offset;
float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
const bool is_master_thread = (threadIdx.x == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[ITERATIONS];
initialize_barriers<ITERATIONS, THREADS_PER_BLOCK>(mbar, is_master_thread);
int parity = 0;
copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread);
#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
const int buff = iter % SHMEM_BUFFERS;
const int it_offset = iter * SHMEM_DIM;
const int next_iter = iter + 1;
const int next_buff = next_iter % SHMEM_BUFFERS;
const int next_iter_offset = next_iter * SHMEM_DIM;
if (next_iter < ITERATIONS) {
copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN,
&(mbar[next_iter]), is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[iter], parity);
#pragma unroll
for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) {
const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
float elt = static_cast<float>(in_sh[buff][shmem_offset]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(elt));
out_sh[buff][shmem_offset] = static_cast<OType>(elt * scale);
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
ptx::cp_async_bulk_tensor_1d_shared_to_global(
reinterpret_cast<uint64_t *>(output + it_offset),
reinterpret_cast<uint64_t *>(&out_sh[buff]), transaction_size_OUT);
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for TMA transfer to have finished reading shared memory.
ptx::cp_async_bulk_wait_group_read<1>();
}
}
ptx::cp_async_bulk_wait_group_read<0>();
__syncthreads();
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
amax = reduce_max<THREADS_PER_BLOCK / THREADS_PER_WARP>(amax, warp_id);
// Update the global amax
if (is_master_thread) {
atomicMaxFloat(amax_ptr, amax);
}
}
// Update scale-inverse
if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) {
reciprocal<float>(scale_inv_ptr, scale);
}
destroy_barriers<ITERATIONS>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
constexpr size_t DBIAS_THREADS_PER_BLOCK = 256;
template <int nvec, typename OType>
__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK)
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows,
const int cols) {
using ComputeVec = Vec<float, nvec>;
using OutputVec = Vec<OType, nvec>;
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= cols) {
return;
}
const float *const thread_in_base = dbias_partial + thread_id * nvec;
OType *const thread_out_base = dbias_output + thread_id * nvec;
ComputeVec ldg_vec;
ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < rows; ++i) {
ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base);
}
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
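// Each thread stores 8 bytes (stg.64), i.e. nvec = 4 elements for a 16-bit IType or 2 for fp32.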
NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, IType>
<<<reduce_dbias_num_blocks, DBIAS_THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
const size_t N = product(input.data.shape);
const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
NVTE_CHECK(isFullTile, "Only full tiles are supported.");
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
const size_t chunks = DIVUP(N, CHUNK_SIZE);
const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK);
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
const float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block(THREADS_PER_BLOCK);
const dim3 grid(blocks);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
const IType *input_ptr = reinterpret_cast<const IType *>(input.data.dptr);
OType *output_ptr = reinterpret_cast<OType *>(output->data.dptr);
cast_fp8_1D_kernel<IS_ACT, ParamOP, OP, IType, OType><<<grid, block, 0, stream>>>(
input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*)
); // NOLINT(*)
}
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias,
Tensor *workspace, cudaStream_t stream) {
checkCuDriverContext(stream);
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X);
const size_t blocks_Y = chunks_Y;
const size_t blocks_X = chunks_X;
const size_t dbias_rows = blocks_Y;
const size_t dbias_cols = cols;
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {dbias_rows, dbias_cols};
workspace->data.dtype = DType::kFloat32;
return;
}
}
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const scale_inv_ptr = reinterpret_cast<float *>(output->scale_inv.dptr);
float *const scale_ptr = reinterpret_cast<float *>(output->scale.dptr);
const dim3 block(FP8_THREADS_PER_CHUNK);
const dim3 grid(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->data.dtype, OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
}
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(OType));
cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);
if constexpr (IS_DBIAS) {
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*)
); // NOLINT(*)
}
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const Tensor *noop, // TODO (ksivamani)
Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data();
checkCuDriverContext(stream);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
const auto &input_shape = input.data.shape;
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
if (use_rowwise_scaling) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
}
if (use_colwise_scaling) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Columnwise scaling tensor must be allocated");
}
CheckNoopTensor(*noop, "cast_noop");
// TODO: Make more general
const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1;
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X);
const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y);
const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X);
const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1;
const size_t scale_stride_colwise =
use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;
e8m0_t *const scales_rowwise_ptr =
use_rowwise_scaling ? reinterpret_cast<e8m0_t *>(output->scale_inv.dptr) : nullptr;
e8m0_t *const scales_colwise_ptr =
use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const size_t dbias_rows = blocks_Y;
const size_t dbias_cols = cols;
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {dbias_rows, dbias_cols};
workspace->data.dtype = DType::kFloat32;
return;
}
}
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const dim3 block(MXFP8_THREADS_PER_CHUNK);
const dim3 grid(blocks_X, blocks_Y);
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_Y_colwise, SCALE_DIM_Y,
TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
scale_dim_X_rowwise, SCALE_DIM_X,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(IType));
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(OType));
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows,
cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(OType));
}
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
SCALE_DIM_Y, SCALE_DIM_X><<<grid, block, 0, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr,
reinterpret_cast<const float *>(noop->data.dptr), workspace_ptr, amax_ptr,
rows, cols, scale_stride_rowwise, scale_stride_colwise);
if constexpr (IS_DBIAS) {
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
}
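// Small device functors used by the launchers below: identity is the default
// unary op when OP == nullptr; dequantize_func scales a value by the stored
// scale_inv.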
namespace detail {
using Empty = transformer_engine::Empty;
__device__ inline float identity(float value, const Empty &) { return value; }
struct DequantizeParam {
const float *scale_inv;
};
__device__ inline float dequantize_func(float value, const DequantizeParam &param) {
return value * (*(param.scale_inv));
}
} // namespace detail
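// Vectorized elementwise cast with an optional fused unary op. Used as a
// fallback when the TMA-based kernels cannot be used; it handles non-FP8
// outputs and FP8 outputs with per-tensor (delayed) scaling only.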
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output,
cudaStream_t stream) {
constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, ParamOP, UnaryOP>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<const fp32 *>(noop->data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
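// Gradient variant of the launcher above: fuses the unary op with an incoming
// gradient tensor (dAct path) before the cast.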
template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output,
cudaStream_t stream) {
constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP;
const size_t N = product(input->data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input->data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, ParamOP, UnaryOP>(
reinterpret_cast<const IType *>(grad.data.dptr),
reinterpret_cast<const IType *>(input->data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
); // NOLINT(*)
}
namespace {
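// The 1D cast kernel processes ELEMS_PER_BLOCK elements per block, so it can
// only handle tensors whose total element count is an exact multiple of that.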
static bool is_full_tile_1D_tensor(const Tensor *const t) {
const size_t N = product(t->data.shape);
const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
return isFullTile;
}
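// TMA transfers require 16-byte alignment of each row, i.e. the innermost
// dimension must be a multiple of 16 / sizeof(dtype) elements.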
bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr int TMA_bytes = 16;
const int alignment_requirement = TMA_bytes / typeToSize(t->dtype());
return cols % alignment_requirement == 0;
}
} // namespace
// Implementation for GPUs with compute capability >= 10.0
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop,
Tensor *output, Tensor *dbias, Tensor *workspace,
cudaStream_t stream) {
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (!IS_DBIAS && !IS_DACT) {
if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype())) {
// Aligned AND FP8
cast_fp8_1D<IS_ACT, ParamOP, OP>(input, output, stream);
} else {
// Unaligned
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
}
} else if (!IS_DBIAS && IS_DACT) {
if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype())) {
// Aligned AND FP8 (+dAct)
cast_fp8_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream);
} else {
// Unaligned
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
}
} else {
cast_fp8_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output, dbias,
workspace, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
// Implementation for GPUs with compute capability < 10.0
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop,
Tensor *output, Tensor *dbias, Tensor *workspace,
cudaStream_t stream) {
if (!is_delayed_tensor_scaling(output->scaling_mode) || IS_DBIAS) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) +
" on GPU with compute capability < 10.0.");
}
if (!IS_DACT) {
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
} else {
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
}
}
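// Top-level quantization dispatch: validates the input/output/dbias tensors,
// then routes to the compute-capability-specific implementation above.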
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output,
Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias != nullptr);
CheckOutputTensor(*dbias, "dbias");
}
if constexpr (IS_DACT) {
NVTE_CHECK(act_input != nullptr);
CheckInputTensor(*act_input, "activation_input");
NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match.");
}
NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
  // Compute capability >= 10.0
if (is_supported_by_CC_100()) {
fp8_quantize_arch_ge_100<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
} else {
    // Compute capability < 10.0
fp8_quantize_arch_l_100<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
}
}
namespace detail {
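// Shared dispatch helper: for backward calls (IS_DBIAS or IS_DACT) the
// incoming gradient becomes the kernel input and `input` is treated as the
// saved activation input; the work is then dispatched on the output tensor's
// scaling mode (delayed tensor scaling vs. MXFP8 1D scaling).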
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
const Tensor *input_tensor;
const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) {
// backward - input is incoming gradient
input_tensor = reinterpret_cast<const Tensor *>(grad);
activation_input_tensor = reinterpret_cast<const Tensor *>(input);
} else {
// forward - input is activation input
input_tensor = reinterpret_cast<const Tensor *>(input);
activation_input_tensor = nullptr;
}
auto output_tensor = reinterpret_cast<Tensor *>(output);
auto dbias_tensor = reinterpret_cast<Tensor *>(dbias);
auto workspace_tensor = reinterpret_cast<Tensor *>(workspace);
const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor();
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (output_tensor->has_columnwise_data()) {
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
cast_transpose(*input_tensor, noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>(
*input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
}
} else if (output_tensor->has_data()) {
fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
*input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
*input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
}
} // namespace detail
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_