Unverified Commit c7702309 authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

rtx5090 arch fix support (#1659)



* rtx5090 arch fix support
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* append `nvte` to the function name so that it's visible in framework-specific dirs
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix typo
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add filter for nvte_is_supported_nontn_fp8_gemm
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* properly expose the api
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feedback from PR
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* move the function to apt header/c files
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add more info
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 91405eb4
...@@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
) )
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
import transformer_engine_torch as tex import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
...@@ -400,7 +400,7 @@ class TestCurrentScalingFloat8Tensor: ...@@ -400,7 +400,7 @@ class TestCurrentScalingFloat8Tensor:
"""Check numerical error when casting to FP8""" """Check numerical error when casting to FP8"""
# Skip invalid configurations # Skip invalid configurations
if non_tn_fp8_gemm_supported() and return_transpose: if is_non_tn_fp8_gemm_supported() and return_transpose:
pytest.skip("FP8 transpose is neither needed nor supported on current system") pytest.skip("FP8 transpose is neither needed nor supported on current system")
# Initialize random high precision data # Initialize random high precision data
......
...@@ -92,9 +92,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -92,9 +92,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
GemmParam ret; GemmParam ret;
// Device compute capability
const int arch = cuda::sm_arch();
// Transpose mode with column-major ordering // Transpose mode with column-major ordering
bool is_A_transposed = transA == CUBLAS_OP_T; bool is_A_transposed = transA == CUBLAS_OP_T;
bool is_B_transposed = transB == CUBLAS_OP_T; bool is_B_transposed = transB == CUBLAS_OP_T;
...@@ -107,7 +104,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -107,7 +104,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Atype = A.data.dtype; ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr; ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m; ret.lda = is_A_transposed ? k : m;
if (arch < 100 && !is_A_transposed) { if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr; ret.A = A.columnwise_data.dptr;
...@@ -166,7 +163,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ...@@ -166,7 +163,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Btype = B.data.dtype; ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr; ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k; ret.ldb = is_B_transposed ? n : k;
if (arch < 100 && is_B_transposed) { if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr; ret.B = B.columnwise_data.dptr;
......
...@@ -332,6 +332,12 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, ...@@ -332,6 +332,12 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
*/ */
void nvte_destroy_quantization_config(NVTEQuantizationConfig config); void nvte_destroy_quantization_config(NVTEQuantizationConfig config);
/*! \brief Check if non-TN FP8 GEMM is supported.
 *
 * \return A flag which indicates whether non-TN FP8 GEMM is supported or not.
 */
int nvte_is_non_tn_fp8_gemm_supported();
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <iostream> #include <iostream>
#include "common.h" #include "common.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -474,3 +475,13 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { ...@@ -474,3 +475,13 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config); delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
} }
} }
int nvte_is_non_tn_fp8_gemm_supported() {
  // Compute capability (SM architecture) of the currently active device.
  const int sm =
      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
  // Non-TN FP8 GEMM layouts are enabled for the sm100 family and for sm130+,
  // while the sm120 family is excluded (presumably the consumer Blackwell
  // parts targeted by this fix — confirm against cuBLAS support).
  // NOTE: this is a temporary restriction and should be lifted in the future
  // (remove this note once it is done).
  const bool supported = (sm >= 100 && sm < 120) || (sm >= 130);
  return supported ? 1 : 0;
}
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#include "common.h" #include "common.h"
#include "pybind.h" #include "pybind.h"
#include "torch/torch.h" #include "torch/torch.h"
#include "util.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
...@@ -103,7 +102,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor( ...@@ -103,7 +102,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
} }
const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data; at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (create_transpose) { if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts); columnwise_data = at::empty(columnwise_torch_shape, opts);
} }
...@@ -215,7 +214,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso ...@@ -215,7 +214,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
} }
const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data; at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (create_transpose) { if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts); columnwise_data = at::empty(columnwise_torch_shape, opts);
} }
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "util.h"
#include "ATen/cuda/CUDAContextLight.h"
// Returns true when the current device supports non-TN layouts for FP8 GEMMs,
// i.e. its major compute capability is 10 (Blackwell) or newer.
bool non_tn_fp8_gemm_supported() {
  const auto *device_props = at::cuda::getCurrentDeviceProperties();
  return device_props->major >= 10;
}
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
bool non_tn_fp8_gemm_supported();
/* Swizzle the scaling factor of the input tensor. /* Swizzle the scaling factor of the input tensor.
* *
* The returned swizzled scaling factor tensor should be kept alive during the GEMM. * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
......
...@@ -19,7 +19,11 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP ...@@ -19,7 +19,11 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm from .utils import (
is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data,
needs_quantized_gemm,
)
from .constants import dist_group_type from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
...@@ -938,7 +942,7 @@ def _all_gather_fp8( ...@@ -938,7 +942,7 @@ def _all_gather_fp8(
# Make sure FP8 transpose is populated if needed # Make sure FP8 transpose is populated if needed
needs_transpose = ( needs_transpose = (
quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported() quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
) )
if needs_transpose: if needs_transpose:
if handle is not None: if handle is not None:
......
...@@ -42,7 +42,7 @@ from ..utils import ( ...@@ -42,7 +42,7 @@ from ..utils import (
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
clear_tensor_data, clear_tensor_data,
requires_grad, requires_grad,
non_tn_fp8_gemm_supported, is_non_tn_fp8_gemm_supported,
needs_quantized_gemm, needs_quantized_gemm,
) )
from ..distributed import ( from ..distributed import (
...@@ -1006,7 +1006,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1006,7 +1006,7 @@ class _LayerNormMLP(torch.autograd.Function):
# All-gather executed on columnwise data and result is in rowwise data, # All-gather executed on columnwise data and result is in rowwise data,
# so we need to fix the interleaving before WGRAD. # so we need to fix the interleaving before WGRAD.
ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size) ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
elif not non_tn_fp8_gemm_supported(): elif not is_non_tn_fp8_gemm_supported():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must # FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose. # have a valid transpose.
ln_out_total._create_transpose() ln_out_total._create_transpose()
......
...@@ -32,7 +32,7 @@ from ..utils import ( ...@@ -32,7 +32,7 @@ from ..utils import (
init_method_constant, init_method_constant,
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
non_tn_fp8_gemm_supported, is_non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
...@@ -640,7 +640,7 @@ class _Linear(torch.autograd.Function): ...@@ -640,7 +640,7 @@ class _Linear(torch.autograd.Function):
inputmat_total = _fix_gathered_fp8_transpose( inputmat_total = _fix_gathered_fp8_transpose(
inputmat_total, ctx.tp_size inputmat_total, ctx.tp_size
) )
elif not non_tn_fp8_gemm_supported(): elif not is_non_tn_fp8_gemm_supported():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must # FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose. # have a valid transpose.
inputmat_total._create_transpose() inputmat_total._create_transpose()
......
...@@ -11,7 +11,7 @@ import torch ...@@ -11,7 +11,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import DType as TE_DType
from ..utils import canonicalize_process_group, devices_match, non_tn_fp8_gemm_supported from ..utils import canonicalize_process_group, devices_match, is_non_tn_fp8_gemm_supported
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..constants import dist_group_type from ..constants import dist_group_type
...@@ -432,7 +432,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -432,7 +432,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
has_data_transpose = self._transpose is not None and not self._transpose_invalid has_data_transpose = self._transpose is not None and not self._transpose_invalid
needs_data = has_data needs_data = has_data
needs_data_transpose = has_data_transpose needs_data_transpose = has_data_transpose
if non_tn_fp8_gemm_supported(): if is_non_tn_fp8_gemm_supported():
if rowwise_usage is not None and rowwise_usage: if rowwise_usage is not None and rowwise_usage:
needs_data = True needs_data = True
if columnwise_usage is not None and columnwise_usage: if columnwise_usage is not None and columnwise_usage:
......
...@@ -251,11 +251,12 @@ def is_bf16_compatible() -> None: ...@@ -251,11 +251,12 @@ def is_bf16_compatible() -> None:
return torch.cuda.get_device_capability()[0] >= 8 return torch.cuda.get_device_capability()[0] >= 8
def is_non_tn_fp8_gemm_supported() -> bool:
    """Check whether the current device supports non-TN layouts for FP8 GEMMs.

    Supported on compute capability in [10.0, 12.0) or >= 13.0; the 12.x
    family is excluded (presumably the consumer Blackwell parts this fix
    targets — confirm against cuBLAS support).

    Requires a CUDA device to be available (queries
    ``torch.cuda.get_device_capability``).
    """
    device_capability = torch.cuda.get_device_capability()
    # Tuple comparison: (major, minor) ordering matches numeric capability.
    return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment