Unverified commit c7702309 authored by Sudhakar Singh, committed by GitHub

rtx5090 arch fix support (#1659)



* rtx5090 arch fix support
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* append `nvte` to the function name so that it's visible in framework-specific dirs
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix typo
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add filter for nvte_is_supported_nontn_fp8_gemm
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* properly expose the api
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feedback from PR
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* move the function to the appropriate header/source files
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add more info
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 91405eb4
@@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
-from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
+from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
 import transformer_engine_torch as tex
 from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
@@ -400,7 +400,7 @@ class TestCurrentScalingFloat8Tensor:
         """Check numerical error when casting to FP8"""
         # Skip invalid configurations
-        if non_tn_fp8_gemm_supported() and return_transpose:
+        if is_non_tn_fp8_gemm_supported() and return_transpose:
             pytest.skip("FP8 transpose is neither needed nor supported on current system")
         # Initialize random high precision data
......
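For readers following the rename: this helper now gates every transpose-specific code path touched by the PR. A minimal standalone sketch of the test-side gating pattern (test name and body are illustrative, not the actual test):

```python
import pytest

from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported


@pytest.mark.parametrize("return_transpose", [False, True])
def test_fp8_cast_sketch(return_transpose):
    # Where non-TN FP8 GEMMs run natively, no FP8 transpose is ever
    # materialized, so transpose-specific cases are skipped up front.
    if is_non_tn_fp8_gemm_supported() and return_transpose:
        pytest.skip("FP8 transpose is neither needed nor supported on current system")
    # ... cast to FP8 and compare against the high-precision reference ...
```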
@@ -92,9 +92,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
   GemmParam ret;
-
-  // Device compute capability
-  const int arch = cuda::sm_arch();
   // Transpose mode with column-major ordering
   bool is_A_transposed = transA == CUBLAS_OP_T;
   bool is_B_transposed = transB == CUBLAS_OP_T;
@@ -107,7 +104,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     ret.Atype = A.data.dtype;
     ret.A_scale_inv = A.scale_inv.dptr;
     ret.lda = is_A_transposed ? k : m;
-    if (arch < 100 && !is_A_transposed) {
+    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
       // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
       if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
         ret.A = A.columnwise_data.dptr;
@@ -166,7 +163,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     ret.Btype = B.data.dtype;
     ret.B_scale_inv = B.scale_inv.dptr;
     ret.ldb = is_B_transposed ? n : k;
-    if (arch < 100 && is_B_transposed) {
+    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
       // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
       if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
         ret.B = B.columnwise_data.dptr;
......
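The two hunks above replace a hard-coded `arch < 100` test with the shared query, so the TN fallback now also triggers on sm_120. A rough Python sketch of the operand-selection rule (the function and argument names are illustrative, not the actual C++ API):

```python
def pick_fp8_operand(has_columnwise, is_transposed, non_tn_supported):
    """Choose which buffer feeds cuBLASLt for an FP8 GEMM operand."""
    if not non_tn_supported and not is_transposed:
        # TN-only devices (Hopper, and now sm_120): the column-wise buffer
        # already holds the transpose of the row-wise data, so use it.
        assert has_columnwise, "TN-only FP8 GEMM needs column-wise (transposed) data"
        return "columnwise_data"
    return "rowwise_data"
```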
@@ -332,6 +332,12 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
  */
 void nvte_destroy_quantization_config(NVTEQuantizationConfig config);
 
+/*! \brief Check if non-TN FP8 Gemm is supported.
+ *
+ * \return A flag which indicates whether non-TN FP8 Gemm is supported or not.
+ */
+int nvte_is_non_tn_fp8_gemm_supported();
+
 #ifdef __cplusplus
 }  // extern "C"
......
@@ -10,6 +10,7 @@
 #include <iostream>
 
 #include "common.h"
+#include "common/util/cuda_runtime.h"
 
 namespace transformer_engine {
@@ -474,3 +475,13 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
   delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
 }
+
+int nvte_is_non_tn_fp8_gemm_supported() {
+  int deviceComputeCapability =
+      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
+
+  // Note: this is a temporary restriction and should be lifted in the future.
+  // (Remove the note once it's done.)
+  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
+         deviceComputeCapability >= 130;
+}
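The bands map onto concrete parts: sm_90 (Hopper) and sm_120 (consumer Blackwell such as the RTX 5090, the subject of this PR) only support TN FP8 GEMMs, while sm_100-class datacenter Blackwell and assumed-future sm_130+ parts support non-TN layouts. A Python mirror of the check, for illustration only:

```python
def is_non_tn_fp8_gemm_supported_mirror(cc):
    # cc is 10 * major + minor, as returned by cuda::sm_arch().
    return (100 <= cc < 120) or cc >= 130


assert not is_non_tn_fp8_gemm_supported_mirror(90)   # Hopper (H100): TN-only
assert is_non_tn_fp8_gemm_supported_mirror(100)      # datacenter Blackwell (B200)
assert not is_non_tn_fp8_gemm_supported_mirror(120)  # RTX 5090: TN-only for now
assert is_non_tn_fp8_gemm_supported_mirror(130)      # future architectures
```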
@@ -9,7 +9,6 @@
 #include "common.h"
 #include "pybind.h"
 #include "torch/torch.h"
-#include "util.h"
 
 namespace transformer_engine::pytorch {
@@ -103,7 +102,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   }
   const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
   at::Tensor columnwise_data;
-  bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
+  bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
   if (create_transpose) {
     columnwise_data = at::empty(columnwise_torch_shape, opts);
   }
@@ -215,7 +214,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
   }
   const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
   at::Tensor columnwise_data;
-  bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
+  bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
   if (create_transpose) {
     columnwise_data = at::empty(columnwise_torch_shape, opts);
   }
......
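Net effect of the two quantizer hunks: on TN-only devices the quantizers again allocate the extra column-wise (transposed) FP8 buffer, now including RTX 5090. A minimal sketch of that decision from the Python side, assuming only the public helper shown later in this diff (the `allocate_fp8_buffers` name is hypothetical):

```python
import torch

from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported


def allocate_fp8_buffers(shape, columnwise_usage):
    # The row-wise FP8 buffer is always created; the transposed copy is
    # only needed on TN-only devices (Hopper, sm_120), where cuBLASLt FP8
    # GEMMs require the TN layout.
    data = torch.empty(shape, dtype=torch.uint8, device="cuda")
    columnwise_data = None
    if columnwise_usage and not is_non_tn_fp8_gemm_supported():
        columnwise_data = torch.empty(
            (shape[-1], *shape[:-1]), dtype=torch.uint8, device="cuda"
        )
    return data, columnwise_data
```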
deleted file
-/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- *
- * See LICENSE for license information.
- ************************************************************************/
-
-#include "util.h"
-
-#include "ATen/cuda/CUDAContextLight.h"
-
-bool non_tn_fp8_gemm_supported() {
-  int major = at::cuda::getCurrentDeviceProperties()->major;
-  return major >= 10;
-}
@@ -13,8 +13,6 @@
 #include "transformer_engine/transformer_engine.h"
 
-bool non_tn_fp8_gemm_supported();
-
 /* Swizzle the scaling factor of the input tensor.
  *
  * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
......
@@ -19,7 +19,11 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
 
-from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
+from .utils import (
+    is_non_tn_fp8_gemm_supported,
+    safely_set_viewless_tensor_data,
+    needs_quantized_gemm,
+)
 from .constants import dist_group_type
 from .fp8 import FP8GlobalStateManager, fp8_autocast
 from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
@@ -938,7 +942,7 @@ def _all_gather_fp8(
     # Make sure FP8 transpose is populated if needed
     needs_transpose = (
-        quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
+        quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
     )
     if needs_transpose:
         if handle is not None:
......
@@ -42,7 +42,7 @@ from ..utils import (
     assert_dim_for_fp8_exec,
     clear_tensor_data,
     requires_grad,
-    non_tn_fp8_gemm_supported,
+    is_non_tn_fp8_gemm_supported,
     needs_quantized_gemm,
 )
 from ..distributed import (
@@ -1006,7 +1006,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     # All-gather executed on columnwise data and result is in rowwise data,
                     # so we need to fix the interleaving before WGRAD.
                     ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
-                elif not non_tn_fp8_gemm_supported():
+                elif not is_non_tn_fp8_gemm_supported():
                     # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                     # have a valid transpose.
                     ln_out_total._create_transpose()
......
@@ -32,7 +32,7 @@ from ..utils import (
     init_method_constant,
     requires_grad,
     needs_quantized_gemm,
-    non_tn_fp8_gemm_supported,
+    is_non_tn_fp8_gemm_supported,
     assert_dim_for_fp8_exec,
     nvtx_range_pop,
     nvtx_range_push,
@@ -640,7 +640,7 @@ class _Linear(torch.autograd.Function):
                         inputmat_total = _fix_gathered_fp8_transpose(
                             inputmat_total, ctx.tp_size
                         )
-                elif not non_tn_fp8_gemm_supported():
+                elif not is_non_tn_fp8_gemm_supported():
                     # FP8 GEMM on Hopper only supports TN layout so the gathered input must
                     # have a valid transpose.
                     inputmat_total._create_transpose()
......
@@ -11,7 +11,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
 
-from ..utils import canonicalize_process_group, devices_match, non_tn_fp8_gemm_supported
+from ..utils import canonicalize_process_group, devices_match, is_non_tn_fp8_gemm_supported
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..constants import dist_group_type
@@ -432,7 +432,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         has_data_transpose = self._transpose is not None and not self._transpose_invalid
         needs_data = has_data
         needs_data_transpose = has_data_transpose
-        if non_tn_fp8_gemm_supported():
+        if is_non_tn_fp8_gemm_supported():
             if rowwise_usage is not None and rowwise_usage:
                 needs_data = True
             if columnwise_usage is not None and columnwise_usage:
......
@@ -251,11 +251,12 @@ def is_bf16_compatible() -> None:
     return torch.cuda.get_device_capability()[0] >= 8
 
 
-def non_tn_fp8_gemm_supported() -> bool:
+def is_non_tn_fp8_gemm_supported() -> bool:
     """Checks whether the device supports
     non-TN layouts for FP8 GEMMs.
     """
-    return torch.cuda.get_device_capability() >= (10, 0)
+    device_capability = torch.cuda.get_device_capability()
+    return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
 
 
 @functools.lru_cache(maxsize=None)
......
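`torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, and tuples compare lexicographically, so this expression implements the same bands as the C-side integer check. A quick illustrative verification:

```python
# (major, minor) capability -> whether non-TN FP8 GEMMs are supported.
for cap, expected in [
    ((9, 0), False),   # Hopper (H100): TN-only
    ((10, 0), True),   # datacenter Blackwell (B200)
    ((12, 0), False),  # RTX 5090: TN-only for now
    ((13, 0), True),   # future architectures
]:
    assert ((10, 0) <= cap < (12, 0) or cap >= (13, 0)) == expected
```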