Unverified Commit d1d00b3e authored by Kirthi Shankar Sivamani, committed by GitHub

Relax dimension checks for fp8 exec (#106)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 44d64abc
@@ -51,7 +51,7 @@ from .utils import (
     divide,
     get_default_init_method,
     cast_if_needed,
-    check_modulo_16,
+    check_dim_for_fp8_forward_exec,
 )
 from .distributed import (
     set_tensor_model_parallel_attributes,
@@ -666,8 +666,8 @@ class _LayerNormLinear(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
         assert (
-            not fp8 or check_modulo_16(inputmat, weight)
-        ), "Inputs and weights must be divisible by 16 for FP8 execution."
+            not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
+        ), "Input and weight dimensions are not compatible for FP8 execution."
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
@@ -1396,8 +1396,8 @@ class _Linear(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
         assert (
-            not fp8 or check_modulo_16(inputmat, weight)
-        ), "Inputs and weights must be divisible by 16 for FP8 execution."
+            not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
+        ), "Input and weight dimensions are not compatible for FP8 execution."
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
@@ -2012,8 +2012,8 @@ class _LayerNormMLP(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
         assert (
-            not fp8 or check_modulo_16(inputmat, fc1_weight, fc2_weight)
-        ), "Inputs and weights must be divisible by 16 for FP8 execution."
+            not fp8 or check_dim_for_fp8_forward_exec(inputmat, fc1_weight, fc2_weight)
+        ), "Input and weight dimensions are not compatible for FP8 execution."
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
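All three hunks above change the same guard in each autograd Function's forward: the input is flattened to a 2-D matrix, and FP8 execution is only permitted when the flattened input and every weight pass the relaxed dimension check. As a rough standalone sketch of that pattern (the free-function framing and the bare `fp8` flag are simplifications for illustration, not the library's actual API):

```python
import torch

# Helper as defined in the utils hunk below.
def check_dim_for_fp8_forward_exec(*tensors: torch.Tensor) -> bool:
    return all(not t.shape[0] % 8 and not t.shape[1] % 16 for t in tensors)

def forward(inp: torch.Tensor, weight: torch.Tensor, fp8: bool) -> torch.Tensor:
    # weight is (out_features, in_features); the GEMM needs matching inner dims.
    in_features = weight.shape[-1]
    assert inp.shape[-1] == in_features, "GEMM not possible"
    # Collapse all leading dims (batch, sequence, ...) into rows.
    inputmat = inp.view((-1, in_features))
    assert (
        not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
    ), "Input and weight dimensions are not compatible for FP8 execution."
    return inputmat @ weight.t()  # stand-in for the real (possibly FP8) GEMM
```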
...
@@ -179,6 +179,8 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
 
 
-def check_modulo_16(*tensors: Tuple[torch.Tensor, ...]) -> bool:
-    """Check if each dimension of given tensors is divisible by 16."""
-    return all(all(n % 16 == 0 for n in t.shape) for t in tensors)
+def check_dim_for_fp8_forward_exec(*tensors: Tuple[torch.Tensor, ...]) -> bool:
+    """For fp8 fprop (TN layout), inputs and weights must be such
+    that dim0 is divisible by 8 and dim1 is divisible by 16.
+    """
+    return all(not t.shape[0] % 8 and not t.shape[1] % 16 for t in tensors)
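To see what the relaxation buys, a minimal sketch comparing the new check against the old `check_modulo_16` it replaces (both bodies copied from the diff; the example shapes are arbitrary):

```python
import torch

def check_dim_for_fp8_forward_exec(*tensors: torch.Tensor) -> bool:
    """New check: dim0 divisible by 8, dim1 divisible by 16."""
    return all(not t.shape[0] % 8 and not t.shape[1] % 16 for t in tensors)

def check_modulo_16(*tensors: torch.Tensor) -> bool:
    """Old check: every dimension divisible by 16."""
    return all(all(n % 16 == 0 for n in t.shape) for t in tensors)

inputmat = torch.empty(24, 1024)   # 24 rows: divisible by 8 but not by 16
weight = torch.empty(1024, 1024)

print(check_dim_for_fp8_forward_exec(inputmat, weight))  # True  -> FP8 path allowed
print(check_modulo_16(inputmat, weight))                 # False -> rejected before this commit
```

So inputs whose flattened row count is a multiple of 8 but not of 16 (e.g. a micro-batch of 24 tokens) can now take the FP8 path.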