Unverified commit d1d00b3e, authored by Kirthi Shankar Sivamani and committed by GitHub
Browse files

Relax dimension checks for fp8 exec (#106)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 44d64abc
......@@ -51,7 +51,7 @@ from .utils import (
divide,
get_default_init_method,
cast_if_needed,
check_modulo_16,
check_dim_for_fp8_forward_exec,
)
from .distributed import (
set_tensor_model_parallel_attributes,
......@@ -666,8 +666,8 @@ class _LayerNormLinear(torch.autograd.Function):
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_modulo_16(inputmat, weight)
), "Inputs and weights must be divisible by 16 for FP8 execution."
not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
), "Input and weight dimensions are not compatible for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......@@ -1396,8 +1396,8 @@ class _Linear(torch.autograd.Function):
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_modulo_16(inputmat, weight)
), "Inputs and weights must be divisible by 16 for FP8 execution."
not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
), "Input and weight dimensions are not compatible for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......@@ -2012,8 +2012,8 @@ class _LayerNormMLP(torch.autograd.Function):
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_modulo_16(inputmat, fc1_weight, fc2_weight)
), "Inputs and weights must be divisible by 16 for FP8 execution."
not fp8 or check_dim_for_fp8_forward_exec(inputmat, fc1_weight, fc2_weight)
), "Input and weight dimensions are not compatible for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......
......@@ -179,6 +179,8 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
def check_modulo_16(*tensors: torch.Tensor) -> bool:
    """Check whether each dimension of every given tensor is divisible by 16.

    Parameters
    ----------
    *tensors : torch.Tensor
        Tensors whose shapes are validated. Note: per PEP 484, a ``*args``
        parameter is annotated with its element type (``torch.Tensor``),
        not ``Tuple[torch.Tensor, ...]``.

    Returns
    -------
    bool
        ``True`` iff all dimensions of all tensors are multiples of 16.
        Vacuously ``True`` when called with no tensors.
    """
    return all(all(n % 16 == 0 for n in t.shape) for t in tensors)
def check_dim_for_fp8_forward_exec(*tensors: torch.Tensor) -> bool:
    """Check shape compatibility for the FP8 forward GEMM (TN layout).

    For fp8 fprop, each input/weight matrix must have ``dim0`` divisible
    by 8 and ``dim1`` divisible by 16.

    Parameters
    ----------
    *tensors : torch.Tensor
        Matrices to validate; callers pass 2-D ``inputmat``/weight views,
        so each tensor is assumed to have at least two dimensions.
        (PEP 484: the ``*args`` annotation names the element type.)

    Returns
    -------
    bool
        ``True`` iff every tensor satisfies both divisibility constraints.
        Vacuously ``True`` when called with no tensors.
    """
    return all(t.shape[0] % 8 == 0 and t.shape[1] % 16 == 0 for t in tensors)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment