Unverified Commit 1d848f22 authored by vasunvidia's avatar vasunvidia Committed by GitHub
Browse files

New fp8_transpose_dbias kernel (#73)



* Initial commit for fp8_transpose_dbias kernel
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* lint fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Suggestions and fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e4a84a8d
......@@ -6,6 +6,7 @@ add_library(transformer_engine SHARED
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
gemm/cublaslt_gemm.cu
......
......@@ -68,6 +68,28 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
NVTETensor workspace,
cudaStream_t stream);
/*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension.
 *
 * This function takes FP8 input and produces 2 results:
 * - `transposed_output` is the transposed result of the input.
 * - `dbias` is the result of the reduction of the input along the first dimension.
 *
 * Calling this function with workspace being an empty tensor will not perform the operation,
 * but instead set the shape and type of the workspace tensor to the required values.
 * A second call with the allocated workspace then performs the fused operation.
 *
 * \param[in]     input             Input tensor of shape [N, H].
 * \param[in,out] transposed_output Result of the transpose. Shape: [H, N].
 * \param[out]    dbias             Result of the reduction of the input along the
 *                                  first dimension. Shape: [H].
 * \param[out]    workspace         Workspace tensor. Pass an empty tensor to query the
 *                                  required shape/dtype (see note above).
 * \param[in]     stream            CUDA stream used for the operation.
 */
void nvte_fp8_transpose_dbias(const NVTETensor input,
                              NVTETensor transposed_output,
                              NVTETensor dbias,
                              NVTETensor workspace,
                              cudaStream_t stream);
/*! \brief Compute backward of GELU operation on the input, then cast and transpose. Additionally,
* reduce the result of the GELU backward along the first dimension.
*
......
This diff is collapsed.
......@@ -211,6 +211,24 @@ def fp8_cast_transpose_bgrad_fused(
)
def fp8_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    grad_bias_type: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Transpose + BGRAD with FP8 output"""
    # Pull the FP8 scaling factors for this tensor slot out of the meta struct
    # before handing everything to the fused extension kernel.
    scale = fp8_meta_tensor.scale[fp8_tensor]
    amax = fp8_meta_tensor.amax_history[0][fp8_tensor]
    scale_inv = fp8_meta_tensor.scale_inv[fp8_tensor]
    return tex.fused_fp8_transpose_bgrad(
        inp,
        scale,
        amax,
        scale_inv,
        otype,
        TE_DType[grad_bias_type],
    )
def fp8_cast_transpose_bgrad_dgelu_fused(
grad_output: torch.Tensor,
gelu_input: torch.Tensor,
......
......@@ -150,6 +150,49 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
}
std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
                                                  at::Tensor scale,
                                                  at::Tensor amax,
                                                  at::Tensor scale_inv,
                                                  transformer_engine::DType otype,
                                                  transformer_engine::DType grad_bias_type
) {
  using namespace transformer_engine;

  // Logical [rows, cols] of the incoming FP8 gradient.
  const size_t M = static_cast<size_t>(grad_output.size(0));
  const size_t N = static_cast<size_t>(grad_output.size(1));

  // Outputs: the transpose is kept as raw FP8 bytes; dbias gets the caller's dtype.
  auto grad_output_transpose = allocateTorchTensor(grad_output.size(1),
                                                   grad_output.size(0),
                                                   DType::kByte);
  auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type);

  // Wrap the torch tensors as TE tensors; the transposed output shares the
  // input's scale/amax/scale_inv since the values are only permuted, not rescaled.
  auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N}, otype,
                                              amax.data_ptr(), scale.data_ptr(),
                                              scale_inv.data_ptr());
  auto transposed_output_cu = makeTransformerEngineTensor(
      grad_output_transpose.data_ptr(), {N, M}, otype,
      amax.data_ptr(), scale.data_ptr(), scale_inv.data_ptr());
  auto dbias_cu = makeTransformerEngineTensor(grad_bias);

  // First call with an empty workspace only queries the required shape/dtype.
  transformer_engine::TensorWrapper workspace;
  nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
                           workspace.data(), at::cuda::getCurrentCUDAStream());

  // Allocate the requested workspace, then run the fused kernel for real.
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
                                          workspace.shape(),
                                          workspace.dtype());
  nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
                           workspace.data(), at::cuda::getCurrentCUDAStream());

  return {grad_bias, grad_output_transpose};
}
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor gelu_input,
at::Tensor scale,
......@@ -852,6 +895,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad,
"Fused FP8 Transpose + BGRAD");
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU");
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
......
......@@ -48,6 +48,15 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
);
/*! Fused FP8 transpose + bias-gradient reduction.
 *  Returns {grad_bias, grad_output_transpose}; `otype` is the FP8 type of
 *  `grad_output`, `grad_bias_type` the dtype of the reduced bias gradient.
 *  See the definition for the workspace-query calling pattern. */
std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
                                                  at::Tensor scale,
                                                  at::Tensor amax,
                                                  at::Tensor scale_inv,
                                                  transformer_engine::DType otype,
                                                  transformer_engine::DType grad_bias_type
);
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor gelu_input,
at::Tensor scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment