Unverified Commit 8c3110d1 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Catch cublas FP8 errors (#317)



* Better dimension assert for FP8
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* line
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a7bc7cf7
...@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union ...@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from ..constants import TE_DType from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec
__all__ = ['gemm', 'fp8_gemm'] __all__ = ['gemm', 'fp8_gemm']
...@@ -41,6 +42,8 @@ def fp8_gemm( ...@@ -41,6 +42,8 @@ def fp8_gemm(
empty_tensor = torch.Tensor() empty_tensor = torch.Tensor()
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_index is not None assert fp8_meta_tensor is not None and out_index is not None
assert_dim_for_fp8_exec(A)
assert_dim_for_fp8_exec(B)
return_output = False return_output = False
if out is None: if out is None:
......
...@@ -27,7 +27,7 @@ from ..utils import ( ...@@ -27,7 +27,7 @@ from ..utils import (
divide, divide,
get_default_init_method, get_default_init_method,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_forward_exec, assert_dim_for_fp8_exec,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -95,8 +95,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -95,8 +95,8 @@ class _LayerNormLinear(torch.autograd.Function):
assert inp.shape[-1] == in_features, "GEMM not possible" assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features)) inputmat = inp.view((-1, in_features))
if fp8: if fp8:
assert_dim_for_fp8_forward_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_forward_exec(weight) assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......
...@@ -30,7 +30,7 @@ from ..utils import ( ...@@ -30,7 +30,7 @@ from ..utils import (
divide, divide,
get_default_init_method, get_default_init_method,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_forward_exec, assert_dim_for_fp8_exec,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -112,9 +112,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -112,9 +112,9 @@ class _LayerNormMLP(torch.autograd.Function):
assert inp.shape[-1] == in_features, "GEMM not possible" assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features)) inputmat = inp.view((-1, in_features))
if fp8: if fp8:
assert_dim_for_fp8_forward_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_forward_exec(fc1_weight) assert_dim_for_fp8_exec(fc1_weight)
assert_dim_for_fp8_forward_exec(fc2_weight) assert_dim_for_fp8_exec(fc2_weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......
...@@ -24,7 +24,7 @@ from ..utils import ( ...@@ -24,7 +24,7 @@ from ..utils import (
divide, divide,
get_default_init_method, get_default_init_method,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_forward_exec, assert_dim_for_fp8_exec,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -81,8 +81,8 @@ class _Linear(torch.autograd.Function): ...@@ -81,8 +81,8 @@ class _Linear(torch.autograd.Function):
assert inp.shape[-1] == in_features, "GEMM not possible" assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features)) inputmat = inp.view((-1, in_features))
if fp8: if fp8:
assert_dim_for_fp8_forward_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
assert_dim_for_fp8_forward_exec(weight) assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......
...@@ -179,19 +179,19 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: ...@@ -179,19 +179,19 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype) return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool:
    """For fp8 fprop (TN layout), inputs and weights must be such
    that dim0 is divisible by 8 and dim1 is divisible by 16.
    """
    # Both alignment constraints must hold for the FP8 GEMM to be executable.
    dim0_ok = tensor.shape[0] % 8 == 0
    dim1_ok = tensor.shape[1] % 16 == 0
    return dim0_ok and dim1_ok


def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None:
    """For fp8 fprop (TN layout), inputs and weights must be such
    that dim0 is divisible by 8 and dim1 is divisible by 16.
    """
    # single tensor check so it's clear which tensor is triggering the assertion
    shape_ok = check_dim_for_fp8_exec(tensor)
    assert shape_ok, (
        "Tensor dimensions are not compatible for FP8 execution: "
        f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)"
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment