Unverified Commit 4dc36f0e authored by Alp Dener, committed by GitHub

[PyTorch] Workaround for incorrect output from torch.cuda.is_bf16_compatible() on V100s and TU102s (#626)

* replaced torch.cuda.is_bf16_compatible() with explicit sm_80 check via torch.cuda.get_device_capability()
Signed-off-by: Alp Dener <adener@nvidia.com>

* implement te.utils.is_bf16_compatible() to replace torch.cuda counterpart
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent f5412e5f
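The workaround boils down to gating bf16 on the device's compute capability rather than on PyTorch's built-in bf16 query, which the commit title reports as giving incorrect results on V100 (sm_70) and TU102 (sm_75). A minimal sketch of that check, assuming a CUDA device is visible (the standalone function name here is illustrative; the actual helper added by this commit appears in the final hunk below):

```python
import torch

def bf16_supported_on_this_device() -> bool:
    """Illustrative check: bf16 requires an Ampere-or-newer GPU (sm_80+)."""
    # get_device_capability() returns (major, minor) for the current CUDA device.
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8
```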
@@ -41,6 +41,7 @@ from transformer_engine.pytorch.utils import (
     get_device_compute_capability,
     init_method_normal,
     scaled_init_method_normal,
+    is_bf16_compatible,
 )
 import transformer_engine_extensions as tex
 from transformer_engine_extensions import NVTE_Fused_Attn_Backend
@@ -194,7 +195,7 @@ model_configs_base = {
 }
 param_types = [torch.float16]
-if torch.cuda.is_bf16_supported():
+if is_bf16_compatible():  # bf16 requires sm_80 or higher
     param_types.append(torch.bfloat16)
 param_types_lean = [torch.bfloat16]

@@ -17,6 +17,7 @@ from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
     attention_mask_func,
+    is_bf16_compatible,
 )
 from transformer_engine.pytorch import (
     DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
@@ -53,7 +54,7 @@ model_configs = {
 }
 param_types = [torch.float32, torch.float16]
-if torch.cuda.is_bf16_supported():
+if is_bf16_compatible():  # bf16 requires sm_80 or higher
     param_types.append(torch.bfloat16)
 batch_sizes = [1, 2]

@@ -13,6 +13,7 @@ from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
+    is_bf16_compatible,
 )
 from transformer_engine.pytorch import (
     LayerNormLinear,
@@ -101,7 +102,7 @@ fp8_recipes = [
 ]
 param_types = [torch.float32, torch.float16]
-if torch.cuda.is_bf16_supported():
+if is_bf16_compatible():  # bf16 requires sm_80 or higher
     param_types.append(torch.bfloat16)
 all_boolean = [True, False]

@@ -222,3 +222,9 @@ def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None:
         "Tensor dimensions are not compatible for FP8 execution: "
         f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)"
     )
+
+
+def is_bf16_compatible() -> bool:
+    """Replaces torch.cuda.is_bf16_compatible() with an explicit
+    check on device compute capability to enforce sm_80 or higher.
+    """
+    return torch.cuda.get_device_capability()[0] >= 8
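
A short usage sketch mirroring the test changes above, assuming the helper is imported from transformer_engine.pytorch.utils as in the import hunks:

```python
import torch
from transformer_engine.pytorch.utils import is_bf16_compatible

# Mirror of the test-suite change: only parametrize over bfloat16 when the
# device is sm_80 or newer, instead of trusting torch.cuda.is_bf16_supported().
param_types = [torch.float32, torch.float16]
if is_bf16_compatible():  # bf16 requires sm_80 or higher
    param_types.append(torch.bfloat16)
```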