Unverified Commit 96ad903c authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Add FP8 support for Ada (#129)



* Add FP8 support for Ada
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* better message
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* better message for no fp8
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* same thing for onnx test
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI and review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 66c10f7a
......@@ -33,6 +33,7 @@ import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.fp8 import is_fp8_available
# Global test configuration knobs.
......@@ -57,10 +58,8 @@ assert OPSET >= TRILU_OPSET
# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
skip_FP8 = pytest.mark.skipif(
torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
reason="Device compute capability 9.x required for FP8 execution.",
)
fp8_available, reason_for_no_fp8 = is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def create_fp8_recipe():
    """Build the delayed-scaling FP8 recipe (E4M3 format) used by these tests."""
    fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
    return fp8_recipe
......@@ -376,8 +375,8 @@ def test_export_gemm(
scale_factors
):
# Skip FP8 tests on devices without FP8 support
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
class TestFP8_GEMM(nn.Module):
def __init__(self, precision, use_bias, gelu, scale_factors):
......@@ -497,8 +496,8 @@ def test_export_layernorm(
zero_centered_gamma: bool
):
# Skip FP8 tests on devices without FP8 support
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Set dimensions (these are arbitrary).
inp_shape = [64, 32]
......@@ -638,8 +637,8 @@ def test_export_linear(
precision: torch.dtype
):
# Skip FP8 tests on devices without FP8 support
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Set dimensions (these are arbitrary).
in_features = 64
......@@ -715,8 +714,8 @@ def test_export_layernorm_linear(
zero_centered_gamma: bool
):
# Skip FP8 tests on devices without FP8 support
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Set dimensions (these are arbitrary).
in_features = 64
......@@ -770,8 +769,8 @@ def test_export_layernorm_mlp(
zero_centered_gamma: bool
):
# Skip FP8 tests on devices without FP8 support
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Set dimensions (these are arbitrary).
in_features = 64
......@@ -890,8 +889,8 @@ def test_export_multihead_attention(
fuse_qkv_params: bool
):
# Skip FP8 tests on devices without FP8 support
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
hidden_size = 256
sequence_length = 128
......@@ -967,8 +966,8 @@ def test_export_transformer_layer(
zero_centered_gamma: bool
):
# Skip FP8 tests on devices without FP8 support
if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Layer configuration
hidden_size = 64
......
......@@ -5,7 +5,7 @@
import torch
import pytest
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_autocast, is_fp8_available
from transformer_engine.pytorch.utils import (
init_method_normal,
scaled_init_method_normal,
......@@ -19,7 +19,7 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
# Only run FP8 tests on devices that support FP8 execution.
fp8_available = torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9
fp8_available, reason_for_no_fp8 = is_fp8_available()
def custom_amax_to_scale(
......@@ -263,7 +263,7 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -291,7 +291,7 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -316,7 +316,7 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -347,7 +347,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -385,7 +385,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -423,7 +423,7 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -461,7 +461,7 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -495,7 +495,7 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -532,7 +532,7 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("skip_wgrad", all_boolean)
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......@@ -570,7 +570,7 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
if fp8_recipe is not None and not fp8_available:
pytest.skip("FP8 device not available.")
pytest.skip(reason_for_no_fp8)
config = model_configs[model]
......
......@@ -5,7 +5,7 @@
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
......
......@@ -23,6 +23,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cublasLt.h>
#include <stdexcept>
#include <memory>
#include <iomanip>
......
......@@ -873,6 +873,11 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
}
// Expose the linked cuBLASLt library version to Python. The Python side uses
// this to gate FP8 execution on Ada, which requires cuBLASLt >= 12.1.3
// (encoded as 120103).
size_t get_cublasLt_version() {
return cublasLtGetVersion();
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
......@@ -907,6 +912,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output");
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>())
......
......@@ -12,6 +12,7 @@ import transformer_engine_extensions as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
from .utils import get_device_compute_capability
_FP8_ENABLED = False
_FP8_CALIBRATION = False
......@@ -26,6 +27,29 @@ _fp8_tensors_recompute_buffer = []
_amax_forward_global_reduce_func = None
_buffer_delete_key_fwd = None
_buffer_delete_key_bwd = None
_is_fp8_available = None
_reason_for_no_fp8 = ""
def _check_fp8_support() -> Tuple[bool, str]:
    """Return whether FP8 execution is supported on the current device.

    Hopper (SM 9.0+) supports FP8 unconditionally. Ada (SM 8.9) additionally
    requires cuBLASLt >= 12.1.3 and CUDA >= 12.1. Anything older has no FP8
    support.

    Returns:
        Tuple of (supported, reason) where ``reason`` is an empty string when
        supported, otherwise a human-readable explanation.
    """
    compute_capability = get_device_compute_capability()
    if compute_capability >= 9.0:  # hopper and above
        return True, ""
    if compute_capability < 8.9:  # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    # torch.version.cuda may carry a patch component (e.g. "12.1.1", which
    # float() cannot parse) or a minor version >= 10 (e.g. "12.10", which
    # float() would misorder as 12.1). Compare (major, minor) integer tuples
    # instead of converting the whole string to float.
    cuda_major, cuda_minor = (int(part) for part in torch.version.cuda.split(".")[:2])
    if (cuda_major, cuda_minor) < (12, 1):
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""
def is_fp8_available() -> Tuple[bool, str]:
    """Return whether FP8 is available, plus the reason when it is not.

    The underlying device/library probe runs only once; its result is cached
    in the module-level globals and reused on every subsequent call.
    """
    global _is_fp8_available, _reason_for_no_fp8
    if _is_fp8_available is None:
        # First call: run the actual support check and memoize the outcome.
        available, reason = _check_fp8_support()
        _is_fp8_available = available
        _reason_for_no_fp8 = reason
    return _is_fp8_available, _reason_for_no_fp8
def get_meta_tensor_key(forward: bool = True) -> str:
......@@ -253,9 +277,8 @@ def fp8_autocast(
_FP8_AUTOCAST_DEPTH += 1
if enabled:
assert (
torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9
), "Device compute capability 9.x required for FP8 execution."
fp8_available, reason_for_no_fp8 = is_fp8_available()
assert fp8_available, reason_for_no_fp8
yield
finally:
_FP8_ENABLED,_FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
......@@ -290,10 +313,12 @@ def is_fp8_enabled() -> bool:
"""Is FP8 enabled"""
return _FP8_ENABLED
def is_fp8_calibration() -> bool:
"""Is FP8 calibration"""
return _FP8_CALIBRATION
def is_first_fp8_module():
"""Returns `True` only the first time when called multiple
times from within the same `fp8_autocast` context.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment