/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

// Declarations for the cast/transpose and blockwise quantize/transpose entry
// points in transformer_engine::detail. This header only declares them; the
// definitions are not visible from here.

#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_

#include "../common.h"

namespace transformer_engine::detail {

// Cast-and-transpose entry point.
//   input   - source tensor (read-only).
//   noop    - read-only tensor; presumably a flag that lets the call be
//             skipped -- NOTE(review): inferred from the name, confirm
//             against the definition.
//   output_ - destination tensor, written through the pointer.
//   stream  - CUDA stream handle passed to the implementation.
void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream);

// Fused variant of cast-transpose.
//   input     - source tensor (read-only).
//   act_input - optional-looking read-only tensor passed by pointer;
//               presumably the activation input -- TODO confirm whether
//               nullptr is accepted.
//   output    - destination tensor.
//   dbias     - presumably receives the bias gradient -- TODO confirm.
//   workspace - presumably scratch space for the implementation -- TODO confirm.
//   stream    - CUDA stream handle passed to the implementation.
// NOTE(review): `template` is not followed by a template parameter list here.
// It looks like the `<...>` list was lost from this declaration; restore it
// from the authoritative copy of this header before relying on it.
template void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output,
                                   Tensor *dbias, Tensor *workspace, cudaStream_t stream);

// Gated-activation backward cast-transpose (purpose inferred from the name --
// NOTE(review): confirm against the definition).
//   input           - source tensor (read-only).
//   gated_act_input - read-only tensor; presumably the forward input of the
//                     gated activation -- TODO confirm.
//   output          - destination tensor.
//   stream          - CUDA stream handle passed to the implementation.
// NOTE(review): as above, the template parameter list after `template`
// appears to be missing from this declaration.
template void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input,
                                        Tensor *output, cudaStream_t stream);

// Blockwise quantize + transpose using square blocks (per the function name).
//   input            - source data (read-only).
//   scale_inv        - written by reference; presumably the inverse scales
//                      for `output` -- TODO confirm.
//   scale_inv_t      - written by reference; presumably the inverse scales
//                      for `output_t` -- TODO confirm.
//   output           - quantized result.
//   output_t         - presumably the transposed quantized result -- TODO confirm.
//   epsilon          - float parameter; presumably guards the scale
//                      computation -- TODO confirm.
//   return_transpose - presumably selects whether `output_t`/`scale_inv_t`
//                      are produced -- TODO confirm.
//   pow_2_scale      - presumably restricts scales to powers of two -- TODO confirm.
//   stream           - CUDA stream handle passed to the implementation.
void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
                                         SimpleTensor &output_t, const float epsilon,
                                         const bool return_transpose, const bool pow_2_scale,
                                         cudaStream_t stream);

// Selects what rowwise data quantize_transpose_vector_blockwise is asked for.
enum class FP8BlockwiseRowwiseOption {
  // No rowwise data.
  NONE,
  // Rowwise data, with scales in GEMM format.
  ROWWISE
  // TODO: FP8 all-gather requires some changes.
  // 1. Compact scales are better for gathering than the GEMM format.
};

// Selects what columnwise data quantize_transpose_vector_blockwise is asked for.
// On Hopper (sm90), which only has a TN FP8 GEMM, the columnwise data needs to
// be transposed when doing 1D block scaling.
enum class FP8BlockwiseColumnwiseOption {
  // No columnwise data.
  NONE,
  // Columnwise data, transposed from the original shape.
  // Scales are in the GEMM format that corresponds to a GEMM ingesting the
  // transposed column data.
  COLUMNWISE_TRANSPOSE
  // TODO: FP8 all-gather requires some changes.
  // 1. The transpose gets in the way of the all-gather.
  // 2. Compact scales are better for gathering than the GEMM format.
};

// Blockwise quantize + transpose using vector (1D) blocks (per the function
// name and the 1D block scaling remark on FP8BlockwiseColumnwiseOption).
//   input             - source data (read-only).
//   scale_inv         - written by reference; presumably the inverse scales
//                       for `output` -- TODO confirm.
//   scale_inv_t       - written by reference; presumably the inverse scales
//                       for `output_t` -- TODO confirm.
//   output            - presumably the rowwise quantized result -- TODO confirm.
//   output_t          - presumably the columnwise (transposed) quantized
//                       result -- TODO confirm.
//   epsilon           - float parameter; presumably guards the scale
//                       computation -- TODO confirm.
//   rowwise_option    - which rowwise data to produce (see FP8BlockwiseRowwiseOption).
//   columnwise_option - which columnwise data to produce
//                       (see FP8BlockwiseColumnwiseOption).
//   pow_2_scale       - presumably restricts scales to powers of two -- TODO confirm.
//   stream            - CUDA stream handle passed to the implementation.
void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
                                         SimpleTensor &scale_inv_t, SimpleTensor &output,
                                         SimpleTensor &output_t, const float epsilon,
                                         FP8BlockwiseRowwiseOption rowwise_option,
                                         FP8BlockwiseColumnwiseOption columnwise_option,
                                         const bool pow_2_scale, cudaStream_t stream);

}  // namespace transformer_engine::detail

#endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_