Unverified Commit 4dc36f0e authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[PyTorch] Workaround for incorrect output from torch.cuda.is_bf16_compatible()...


[PyTorch] Workaround for incorrect output from torch.cuda.is_bf16_compatible() on V100s and TU102s (#626)

* replaced torch.cuda.is_bf16_compatible() with explicit sm_80 check via torch.cuda.get_device_capability()
Signed-off-by: Alp Dener <adener@nvidia.com>

* implement te.utils.is_bf16_compatible() to replace torch.cuda counterpart
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent f5412e5f
...@@ -41,6 +41,7 @@ from transformer_engine.pytorch.utils import ( ...@@ -41,6 +41,7 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability, get_device_compute_capability,
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
is_bf16_compatible,
) )
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from transformer_engine_extensions import NVTE_Fused_Attn_Backend from transformer_engine_extensions import NVTE_Fused_Attn_Backend
...@@ -194,7 +195,7 @@ model_configs_base = { ...@@ -194,7 +195,7 @@ model_configs_base = {
} }
param_types = [torch.float16] param_types = [torch.float16]
if torch.cuda.is_bf16_supported(): if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16) param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16] param_types_lean = [torch.bfloat16]
......
...@@ -17,6 +17,7 @@ from transformer_engine.pytorch.utils import ( ...@@ -17,6 +17,7 @@ from transformer_engine.pytorch.utils import (
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
attention_mask_func, attention_mask_func,
is_bf16_compatible,
) )
from transformer_engine.pytorch import ( from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
...@@ -53,7 +54,7 @@ model_configs = { ...@@ -53,7 +54,7 @@ model_configs = {
} }
param_types = [torch.float32, torch.float16] param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported(): if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16) param_types.append(torch.bfloat16)
batch_sizes = [1, 2] batch_sizes = [1, 2]
......
...@@ -13,6 +13,7 @@ from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager ...@@ -13,6 +13,7 @@ from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
init_method_normal, init_method_normal,
scaled_init_method_normal, scaled_init_method_normal,
is_bf16_compatible,
) )
from transformer_engine.pytorch import ( from transformer_engine.pytorch import (
LayerNormLinear, LayerNormLinear,
...@@ -101,7 +102,7 @@ fp8_recipes = [ ...@@ -101,7 +102,7 @@ fp8_recipes = [
] ]
param_types = [torch.float32, torch.float16] param_types = [torch.float32, torch.float16]
if torch.cuda.is_bf16_supported(): if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16) param_types.append(torch.bfloat16)
all_boolean = [True, False] all_boolean = [True, False]
......
...@@ -222,3 +222,9 @@ def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None: ...@@ -222,3 +222,9 @@ def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None:
"Tensor dimensions are not compatible for FP8 execution: " "Tensor dimensions are not compatible for FP8 execution: "
f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)" f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)"
) )
def is_bf16_compatible() -> bool:
    """Replaces torch.cuda.is_bf16_compatible() with an explicit
    check on device compute capability to enforce sm_80 or higher.

    torch.cuda.is_bf16_compatible() reports incorrect results on some
    pre-Ampere GPUs (e.g. V100, TU102), so we check the major compute
    capability directly instead.

    Returns:
        bool: True if the current CUDA device is sm_80 (Ampere) or newer,
        i.e. supports bfloat16 compute; False otherwise.
    """
    # bf16 tensor-core support was introduced with compute capability 8.0
    return torch.cuda.get_device_capability()[0] >= 8
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment