"vscode:/vscode.git/clone" did not exist on "b5fb5ba05cb6144f17be58b9dba8a35ba022876a"
Unverified Commit 871fdf51 authored by Carlos Mocholí's avatar Carlos Mocholí Committed by GitHub
Browse files

Clearer error messages for dtype and shape assertions (#245)



* Clearer error messages for dtype and shape assertions
Signed-off-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update transformer_engine/pytorch/utils.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarCarlos Mocholí <carlossmocholi@gmail.com>

* Update transformer_engine/pytorch/utils.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarCarlos Mocholí <carlossmocholi@gmail.com>

---------
Signed-off-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 69003969
......@@ -445,25 +445,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if hasattr(self, "activation_dtype"):
return
assert all(
(
(inp.dtype == param.dtype) if param is not None else True
for param in self.parameters()
)
), (
"Data type for activations and weights must "
"match when outside of autocasted region"
)
assert all(
(
(inp.dtype == buf.dtype) if buf is not None else True
for buf in self.buffers()
)
), (
"Data type for activations and buffers must "
"match when outside of autocasted region"
)
self.activation_dtype = inp.dtype
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
for name, buf in self.named_buffers():
if buf is not None:
assert dtype == buf.dtype, (
"Data types for buffers must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {buf.dtype}"
)
self.activation_dtype = dtype
def set_fp8_weights(self) -> None:
"""Initializes FP8 weights for the module as class attributes. These
......
......@@ -27,7 +27,7 @@ from ..utils import (
divide,
get_default_init_method,
cast_if_needed,
check_dim_for_fp8_forward_exec,
assert_dim_for_fp8_forward_exec,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -94,9 +94,9 @@ class _LayerNormLinear(torch.autograd.Function):
# NOTE(review): fragment of _LayerNormLinear.forward; the scrape kept both the
# old combined check and the new per-tensor asserts — only the new form is
# valid against the single-tensor assert_dim_for_fp8_forward_exec signature.
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
# FP8 fprop requires dim0 % 8 == 0 and dim1 % 16 == 0; checking each tensor
# separately makes the failing tensor explicit in the assertion message.
if fp8:
    assert_dim_for_fp8_forward_exec(inputmat)
    assert_dim_for_fp8_forward_exec(weight)

update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......
......@@ -32,7 +32,7 @@ from ..utils import (
divide,
get_default_init_method,
cast_if_needed,
check_dim_for_fp8_forward_exec,
assert_dim_for_fp8_forward_exec,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -107,9 +107,10 @@ class _LayerNormMLP(torch.autograd.Function):
# NOTE(review): fragment of _LayerNormMLP.forward; stale pre-change combined
# check removed — it called check_dim_for_fp8_forward_exec with multiple
# tensors, a signature that no longer exists.
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
# Per-tensor FP8 dimension checks so the error pinpoints the bad tensor.
if fp8:
    assert_dim_for_fp8_forward_exec(inputmat)
    assert_dim_for_fp8_forward_exec(fc1_weight)
    assert_dim_for_fp8_forward_exec(fc2_weight)

update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......
......@@ -24,7 +24,7 @@ from ..utils import (
divide,
get_default_init_method,
cast_if_needed,
check_dim_for_fp8_forward_exec,
assert_dim_for_fp8_forward_exec,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -80,9 +80,9 @@ class _Linear(torch.autograd.Function):
# NOTE(review): fragment of _Linear.forward; stale pre-change combined check
# removed in favor of the per-tensor asserts.
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
# Per-tensor FP8 dimension checks so the error pinpoints the bad tensor.
if fp8:
    assert_dim_for_fp8_forward_exec(inputmat)
    assert_dim_for_fp8_forward_exec(weight)

update_fp8_weights = is_first_microbatch is None or is_first_microbatch
......
......@@ -179,8 +179,19 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
def check_dim_for_fp8_forward_exec(tensor: torch.Tensor) -> bool:
    """Return ``True`` if ``tensor`` has dimensions compatible with FP8 fprop.

    For fp8 fprop (TN layout), inputs and weights must be such
    that dim0 is divisible by 8 and dim1 is divisible by 16.

    The diff scrape had left the pre-change varargs signature and its
    ``return all(...)`` body interleaved here; only the single-tensor
    version is kept.
    """
    return not tensor.shape[0] % 8 and not tensor.shape[1] % 16
def assert_dim_for_fp8_forward_exec(tensor: torch.Tensor) -> None:
    """Assert that ``tensor`` has dimensions compatible with FP8 fprop.

    For fp8 fprop (TN layout), inputs and weights must be such
    that dim0 is divisible by 8 and dim1 is divisible by 16.
    """
    # One tensor per call so the assertion pinpoints the offending tensor.
    compatible = check_dim_for_fp8_forward_exec(tensor)
    assert compatible, (
        "Tensor dimensions are not compatible for FP8 execution: "
        f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)"
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment