Unverified commit 4e036c8c, authored by Xin Yao and committed by GitHub
Browse files

[PyTorch] Move swizzle scaling factor to cpp (#1683)



* move swizzle scaling factor to cpp
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e61ce77c
......@@ -268,7 +268,7 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
size_t nvte_tensor_element_size(const NVTETensor tensor) {
  // Returns the size in bytes of one element of the tensor.
  // A null tensor is treated as FP32 (sizeof(float)), per the existing default.
  if (tensor == nullptr) return sizeof(float);
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  // Query the dtype via the accessor rather than reaching into t.data.dtype
  // directly (the accessor is presumably the canonical source of the logical
  // dtype — this matches the committed change in this diff).
  return transformer_engine::typeToSize(t.dtype());
}
void *nvte_tensor_data(const NVTETensor tensor) {
......
......@@ -12,7 +12,6 @@ from ..constants import TE_DType
from ..utils import get_sm_count
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
......@@ -28,46 +27,6 @@ def _empty_tensor() -> torch.Tensor:
return torch.Tensor().cuda()
def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str):
    """Swizzle the scaling factors of MXFP8 GEMM operands in place.

    Returns a 4-tuple with the pre-swizzle scale inverses of both operands
    (row-wise and column-wise for A, then B) so the caller can restore them
    after the GEMM, or ``None`` when either operand is not an MXFP8 tensor
    and nothing is modified.
    """
    both_mxfp8 = isinstance(A, MXFP8TensorBase) and isinstance(B, MXFP8TensorBase)
    if not both_mxfp8:
        return None

    saved_scale_invs = (
        A._rowwise_scale_inv,
        A._columnwise_scale_inv,
        B._rowwise_scale_inv,
        B._columnwise_scale_inv,
    )

    # Operand A: a "T" layout consumes the row-wise data; otherwise column-wise.
    if layout[0] == "T":
        A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv)
    else:
        A._columnwise_scale_inv = tex.columnwise_swizzle(
            A._columnwise_data, A._columnwise_scale_inv
        )

    # Operand B: an "N" layout consumes the row-wise data; otherwise column-wise.
    if layout[1] == "N":
        B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv)
    else:
        B._columnwise_scale_inv = tex.columnwise_swizzle(
            B._columnwise_data, B._columnwise_scale_inv
        )

    return saved_scale_invs
def reset_swizzled_inputs(A, B, scale_inverses):
    """Restore the pre-swizzle scale inverses on both GEMM operands.

    ``scale_inverses`` is the 4-tuple returned by ``swizzle_inputs``; when it
    is ``None`` nothing was swizzled and this call is a no-op.
    """
    if scale_inverses is None:
        return
    a_row, a_col, b_row, b_col = scale_inverses
    A._rowwise_scale_inv = a_row
    A._columnwise_scale_inv = a_col
    B._rowwise_scale_inv = b_row
    B._columnwise_scale_inv = b_col
def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
......@@ -149,9 +108,7 @@ def general_gemm(
"bulk_overlap": bulk_overlap,
}
original_scale_inverses = swizzle_inputs(A, B, layout)
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
reset_swizzled_inputs(A, B, original_scale_inverses)
if debug_quantizer is not None:
out = debug_quantizer.process_gemm_output(out)
......@@ -210,8 +167,6 @@ def general_grouped_gemm(
for o in out
] # this should differ with respect to single output
# TODO: Move the swizzle to the C++ side. # pylint: disable=fixme
original_scale_inverses_list = [swizzle_inputs(A[i], B[i], layout) for i in range(num_gemms)]
bias = tex.te_general_grouped_gemm(
A,
transa,
......@@ -231,7 +186,5 @@ def general_grouped_gemm(
use_split_accumulator,
sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
)
for i in range(num_gemms):
reset_swizzled_inputs(A[i], B[i], original_scale_inverses_list[i])
return out, bias, gelu_input
......@@ -50,11 +50,11 @@ std::vector<py::object> fused_attn_fwd(
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
......@@ -63,8 +63,8 @@ std::vector<py::object> fused_attn_bwd(
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
......@@ -270,12 +270,12 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
/***************************************************************************************************
......@@ -396,8 +396,6 @@ void nvshmem_finalize();
* swizzle
**************************************************************************************************/
void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans);
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv);
......
......@@ -8,7 +8,7 @@
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch;
......@@ -96,7 +96,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
......
......@@ -92,11 +92,11 @@ std::vector<py::object> fused_attn_fwd(
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
TensorWrapper te_Q, te_K, te_V, te_O, te_S;
......@@ -282,8 +282,8 @@ std::vector<py::object> fused_attn_bwd(
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
......
......@@ -17,6 +17,7 @@
#include "extensions.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
namespace {
......@@ -175,8 +176,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
if (comm_overlap) {
// Prepare extra output tensor
TensorWrapper extra_output_tensor;
......@@ -313,6 +321,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers;
std::vector<at::Tensor> D_vectors;
// Keep the swizzled scaling factor tensors alive during the GEMMs.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto none = py::none();
......@@ -379,6 +389,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
continue;
}
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
auto te_D = makeTransformerEngineTensor(out_tensor);
auto te_bias = makeTransformerEngineTensor(bias[i]);
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
......
......@@ -6,14 +6,16 @@
#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) {
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
if (input.scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
return;
return std::nullopt;
}
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
......@@ -48,9 +50,9 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
} else {
input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0,
scale_inv_shape);
}
......@@ -63,6 +65,8 @@ void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool roww
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
}
return swizzled_scale_inv;
}
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
......
......@@ -7,6 +7,19 @@
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#include <torch/extension.h>
#include <optional>
#include "transformer_engine/transformer_engine.h"
bool non_tn_fp8_gemm_supported();
/* Swizzle the scaling factor of the input tensor.
*
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
*/
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
bool trans);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment