Unverified commit 96ad903c authored by Kirthi Shankar Sivamani, committed by GitHub

Add FP8 support for Ada (#129)



* Add FP8 support for Ada
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* better message
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint fixes
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* better message for no fp8
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* same thing for onnx test
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI and review
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 66c10f7a
@@ -33,6 +33,7 @@ import transformer_engine.pytorch.cpp_extensions as texcpp
 import transformer_engine.pytorch.softmax as softmax_defs
 from transformer_engine.pytorch.utils import get_default_init_method
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
+from transformer_engine.pytorch.fp8 import is_fp8_available

 # Global test configuration knobs.
@@ -57,10 +58,8 @@ assert OPSET >= TRILU_OPSET
 # Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
 ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")

-skip_FP8 = pytest.mark.skipif(
-    torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
-    reason="Device compute capability 9.x required for FP8 execution.",
-)
+fp8_available, reason_for_no_fp8 = is_fp8_available()
+skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)

 def create_fp8_recipe():
     return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
@@ -376,8 +375,8 @@ def test_export_gemm(
     scale_factors
 ):
     # Skip FP8 tests on non-hopper devices
-    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")
+    if use_fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)

     class TestFP8_GEMM(nn.Module):
         def __init__(self, precision, use_bias, gelu, scale_factors):
@@ -497,8 +496,8 @@ def test_export_layernorm(
     zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
-    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")
+    if use_fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)

     # Set dimensions (these are arbitrary).
     inp_shape = [64, 32]
@@ -638,8 +637,8 @@ def test_export_linear(
     precision: torch.dtype
 ):
     # Skip FP8 tests on non-hopper devices
-    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")
+    if use_fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)

     # Set dimensions (these are arbitrary).
     in_features = 64
@@ -715,8 +714,8 @@ def test_export_layernorm_linear(
     zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
-    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")
+    if use_fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)

     # Set dimensions (these are arbitrary).
     in_features = 64
@@ -770,8 +769,8 @@ def test_export_layernorm_mlp(
     zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
-    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")
+    if use_fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)

     # Set dimensions (these are arbitrary).
     in_features = 64
@@ -890,8 +889,8 @@ def test_export_multihead_attention(
     fuse_qkv_params: bool
 ):
     # Skip FP8 tests on non-hopper devices
-    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")
+    if use_fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)

     hidden_size = 256
     sequence_length = 128
@@ -967,8 +966,8 @@ def test_export_transformer_layer(
     zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
-    if use_fp8 and torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")
+    if use_fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)

     # Layer configuration
     hidden_size = 64
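The hunks above replace a hard-coded Hopper-only guard in the ONNX export tests with a single query to the new is_fp8_available() helper. A minimal sketch of the same pattern in an arbitrary pytest module (the test names and bodies here are placeholders, not part of the diff):

import pytest
from transformer_engine.pytorch.fp8 import is_fp8_available

# Query once at import time; every FP8-dependent test reuses the result
# and the human-readable reason string.
fp8_available, reason_for_no_fp8 = is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)

@skip_FP8
def test_always_needs_fp8():
    ...  # placeholder body

@pytest.mark.parametrize("use_fp8", [False, True])
def test_sometimes_needs_fp8(use_fp8):
    # Inline skip for tests where only some parametrizations need FP8.
    if use_fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)

The same substitution is applied to the sanity-test suite below.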
@@ -5,7 +5,7 @@
 import torch
 import pytest

-from transformer_engine.pytorch.fp8 import fp8_autocast
+from transformer_engine.pytorch.fp8 import fp8_autocast, is_fp8_available
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
@@ -19,7 +19,7 @@ from transformer_engine.pytorch import (
 from transformer_engine.common import recipe

 # Only run FP8 tests on H100.
-fp8_available = torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9
+fp8_available, reason_for_no_fp8 = is_fp8_available()

 def custom_amax_to_scale(
@@ -263,7 +263,7 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -291,7 +291,7 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -316,7 +316,7 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -347,7 +347,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -385,7 +385,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -423,7 +423,7 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gam
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -461,7 +461,7 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -495,7 +495,7 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -532,7 +532,7 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
@@ -570,7 +570,7 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
 def test_sanity_gradient_accumulation_fusion(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
     if fp8_recipe is not None and not fp8_available:
-        pytest.skip("FP8 device not available.")
+        pytest.skip(reason_for_no_fp8)

     config = model_configs[model]
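As in the ONNX tests, the sanity tests now evaluate the availability check once at module import and reuse the tuple everywhere. A short sketch of what the tuple yields (the printed strings are illustrative; the exact reason depends on which condition fails):

from transformer_engine.pytorch.fp8 import is_fp8_available

available, reason = is_fp8_available()
if available:
    # Hopper (sm_90) and above, or Ada (sm_89) with new enough cuBLASLt and CUDA.
    print("FP8 supported")
else:
    # e.g. "Device compute capability 8.9 or higher required for FP8 execution."
    print(f"FP8 unavailable: {reason}")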
@@ -5,7 +5,7 @@
 cmake_minimum_required(VERSION 3.18)

 if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-  set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
+  set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
 endif()
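The build change above adds compute capability 8.9 (Ada) to the default CUDA architecture list alongside 7.0, 8.0, and 9.0, so the kernels are also compiled for Ada-class GPUs. A small, hypothetical snippet for checking which class the current GPU falls into from Python:

import torch

# (major, minor), e.g. (8, 9) for Ada (L4/L40/RTX 4090) or (9, 0) for Hopper (H100).
major, minor = torch.cuda.get_device_capability()

if (major, minor) >= (9, 0):
    print("Hopper or newer: FP8 supported unconditionally")
elif (major, minor) >= (8, 9):
    print("Ada: FP8 supported with cuBLASLt >= 12.1.3 and CUDA >= 12.1")
else:
    print("Pre-Ada: FP8 not supported")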
@@ -23,6 +23,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_bf16.h>
+#include <cublasLt.h>
 #include <stdexcept>
 #include <memory>
 #include <iomanip>
@@ -873,6 +873,11 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
 }

+size_t get_cublasLt_version() {
+    return cublasLtGetVersion();
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // Softmax functions
   m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
@@ -907,6 +912,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
   m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output");

+  // Misc
+  m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
+
   // Data structures
   py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
     .def(py::init<>())
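The two hunks above expose the cuBLASLt runtime version to Python: cublasLtGetVersion() returns an integer encoding of the version, and 120103 corresponds to the 12.1.3 threshold used by the Ada check in fp8.py below. A quick sketch of calling the new binding, assuming the compiled extension module is importable:

import transformer_engine_extensions as tex

version = tex.get_cublasLt_version()
print(f"cuBLASLt version: {version}")
# Ada FP8 GEMMs require a sufficiently new cuBLASLt, hence the 12.1.3 floor.
print("cuBLASLt new enough for Ada FP8:", version >= 120103)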
@@ -12,6 +12,7 @@ import transformer_engine_extensions as tex
 from transformer_engine.common.recipe import DelayedScaling, Format
 from .constants import dist_group_type
+from .utils import get_device_compute_capability

 _FP8_ENABLED = False
 _FP8_CALIBRATION = False
@@ -26,6 +27,29 @@ _fp8_tensors_recompute_buffer = []
 _amax_forward_global_reduce_func = None
 _buffer_delete_key_fwd = None
 _buffer_delete_key_bwd = None
+_is_fp8_available = None
+_reason_for_no_fp8 = ""
+
+
+def _check_fp8_support() -> Tuple[bool, str]:
+    """Return if fp8 support is available"""
+    if get_device_compute_capability() >= 9.0:  # hopper and above
+        return True, ""
+    if get_device_compute_capability() < 8.9:  # pre-ada
+        return False, "Device compute capability 8.9 or higher required for FP8 execution."
+    if tex.get_cublasLt_version() < 120103:
+        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
+    if float(torch.version.cuda) < 12.1:
+        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
+    return True, ""
+
+
+def is_fp8_available() -> Tuple[bool, str]:
+    """Return if fp8 support is available"""
+    global _is_fp8_available, _reason_for_no_fp8
+    if _is_fp8_available is None:
+        _is_fp8_available, _reason_for_no_fp8 = _check_fp8_support()
+    return _is_fp8_available, _reason_for_no_fp8
+

 def get_meta_tensor_key(forward: bool = True) -> str:
@@ -253,9 +277,8 @@ def fp8_autocast(
         _FP8_AUTOCAST_DEPTH += 1

         if enabled:
-            assert (
-                torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9
-            ), "Device compute capability 9.x required for FP8 execution."
+            fp8_available, reason_for_no_fp8 = is_fp8_available()
+            assert fp8_available, reason_for_no_fp8
         yield
     finally:
         _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
@@ -290,10 +313,12 @@ def is_fp8_enabled() -> bool:
     """Is FP8 enabled"""
     return _FP8_ENABLED

+
 def is_fp8_calibration() -> bool:
     """Is FP8 calibration"""
     return _FP8_CALIBRATION

+
 def is_first_fp8_module():
     """Returns `True` only the first time when called multiple
     times from within the same `fp8_autocast` context.
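With the fp8.py changes above, fp8_autocast(enabled=True) asserts on the cached availability result and surfaces the specific reason instead of the old blanket "compute capability 9.x" message. A minimal usage sketch (the module and tensor shapes are arbitrary, not from the diff):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import fp8_autocast

model = te.Linear(64, 64, bias=True).cuda()
inp = torch.randn(32, 64, device="cuda")
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

try:
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)
except AssertionError as err:
    # On a pre-Ada GPU, or on Ada with an older cuBLASLt or CUDA toolkit,
    # this prints the reason produced by _check_fp8_support().
    print(f"FP8 not available: {err}")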