Unverified Commit d74ee5b5 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Catch FP8 modulo16 error before cublas and fp8 kernels (#97)



* Catch FP8 modulo16 error before cublas and fp8 kernels
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* annotate
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 66055973
......@@ -51,6 +51,7 @@ from .utils import (
divide,
get_default_init_method,
cast_if_needed,
check_modulo_16,
)
from .distributed import (
set_tensor_model_parallel_attributes,
......@@ -664,6 +665,9 @@ class _LayerNormLinear(torch.autograd.Function):
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_modulo_16(inputmat, weight)
), "Inputs and weights must be divisible by 16 for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......@@ -1391,6 +1395,9 @@ class _Linear(torch.autograd.Function):
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_modulo_16(inputmat, weight)
), "Inputs and weights must be divisible by 16 for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......@@ -2004,6 +2011,9 @@ class _LayerNormMLP(torch.autograd.Function):
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_modulo_16(inputmat, fc1_weight, fc2_weight)
), "Inputs and weights must be divisible by 16 for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......
......@@ -177,3 +177,8 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""Cast tensor to dtype"""
with torch.enable_grad():
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
def check_modulo_16(*tensors: torch.Tensor) -> bool:
    """Check whether every dimension of each given tensor is divisible by 16.

    Per this change's intent, FP8 execution paths (cuBLAS / FP8 kernels)
    require matrix dimensions to be multiples of 16, so callers use this
    to fail fast with a clear assertion message before kernel dispatch.

    Parameters
    ----------
    *tensors : torch.Tensor
        Any number of tensors whose shapes are validated.
        (Note: the variadic annotation is the element type, per PEP 484;
        the previous ``Tuple[torch.Tensor, ...]`` annotation incorrectly
        declared each positional argument to be a tuple.)

    Returns
    -------
    bool
        ``True`` iff every dimension of every tensor is a multiple of 16.
        Vacuously ``True`` when called with no tensors.
    """
    return all(all(dim % 16 == 0 for dim in t.shape) for t in tensors)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment