"vscode:/vscode.git/clone" did not exist on "dbed904cefbf7d201d77dde4f7ace89e4fc23d1b"
Unverified commit 871fdf51, authored by Carlos Mocholí, committed by GitHub

Clearer error messages for dtype and shape assertions (#245)



* Clearer error messages for dtype and shape assertions
Signed-off-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update transformer_engine/pytorch/utils.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update transformer_engine/pytorch/utils.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Carlos Mocholí <carlossmocholi@gmail.com>

---------
Signed-off-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 69003969
@@ -445,25 +445,20 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if hasattr(self, "activation_dtype"):
             return
 
-        assert all(
-            (
-                (inp.dtype == param.dtype) if param is not None else True
-                for param in self.parameters()
-            )
-        ), (
-            "Data type for activations and weights must "
-            "match when outside of autocasted region"
-        )
-        assert all(
-            (
-                (inp.dtype == buf.dtype) if buf is not None else True
-                for buf in self.buffers()
-            )
-        ), (
-            "Data type for activations and buffers must "
-            "match when outside of autocasted region"
-        )
-        self.activation_dtype = inp.dtype
+        dtype = inp.dtype
+        for name, param in self.named_parameters():
+            if param is not None:
+                assert dtype == param.dtype, (
+                    "Data types for parameters must match when outside of autocasted region. "
+                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
+                )
+        for name, buf in self.named_buffers():
+            if buf is not None:
+                assert dtype == buf.dtype, (
+                    "Data types for buffers must match when outside of autocasted region. "
+                    f" Found input dtype: {dtype} and {name!r} dtype: {buf.dtype}"
+                )
+        self.activation_dtype = dtype
 
     def set_fp8_weights(self) -> None:
         """Initializes FP8 weights for the module as class attributes. These
...
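For illustration only (not part of this commit), a minimal sketch of what the reworked per-parameter check reports; the toy Broken module and its deliberately float16 bias are invented for the example:

import torch

# Hypothetical module whose bias dtype deliberately mismatches the float32 input.
class Broken(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float32))
        self.bias = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))

inp = torch.randn(2, 4)  # float32 activations
module = Broken()
dtype = inp.dtype
# Same loop as the new code above: the assertion now names the offending parameter.
for name, param in module.named_parameters():
    if param is not None:
        assert dtype == param.dtype, (
            "Data types for parameters must match when outside of autocasted region. "
            f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
        )
# AssertionError: Data types for parameters must match when outside of autocasted region.
#  Found input dtype: torch.float32 and 'bias' dtype: torch.float16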
@@ -27,7 +27,7 @@ from ..utils import (
     divide,
     get_default_init_method,
     cast_if_needed,
-    check_dim_for_fp8_forward_exec,
+    assert_dim_for_fp8_forward_exec,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -94,9 +94,9 @@ class _LayerNormLinear(torch.autograd.Function):
         in_features = ln_weight.numel()
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
-        assert (
-            not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
-        ), "Input and weight dimensions are not compatible for FP8 execution."
+        if fp8:
+            assert_dim_for_fp8_forward_exec(inputmat)
+            assert_dim_for_fp8_forward_exec(weight)
 
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
...
@@ -32,7 +32,7 @@ from ..utils import (
     divide,
     get_default_init_method,
     cast_if_needed,
-    check_dim_for_fp8_forward_exec,
+    assert_dim_for_fp8_forward_exec,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -107,9 +107,10 @@ class _LayerNormMLP(torch.autograd.Function):
         in_features = ln_weight.numel()
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
-        assert (
-            not fp8 or check_dim_for_fp8_forward_exec(inputmat, fc1_weight, fc2_weight)
-        ), "Input and weight dimensions are not compatible for FP8 execution."
+        if fp8:
+            assert_dim_for_fp8_forward_exec(inputmat)
+            assert_dim_for_fp8_forward_exec(fc1_weight)
+            assert_dim_for_fp8_forward_exec(fc2_weight)
 
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
...
@@ -24,7 +24,7 @@ from ..utils import (
     divide,
     get_default_init_method,
     cast_if_needed,
-    check_dim_for_fp8_forward_exec,
+    assert_dim_for_fp8_forward_exec,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -80,9 +80,9 @@ class _Linear(torch.autograd.Function):
         in_features = weight.shape[-1]
         assert inp.shape[-1] == in_features, "GEMM not possible"
         inputmat = inp.view((-1, in_features))
-        assert (
-            not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
-        ), "Input and weight dimensions are not compatible for FP8 execution."
+        if fp8:
+            assert_dim_for_fp8_forward_exec(inputmat)
+            assert_dim_for_fp8_forward_exec(weight)
 
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
...
@@ -179,8 +179,19 @@ def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
 
 
-def check_dim_for_fp8_forward_exec(*tensors: Tuple[torch.Tensor, ...]) -> bool:
+def check_dim_for_fp8_forward_exec(tensor: torch.Tensor) -> bool:
     """For fp8 fprop (TN layout), inputs and weights must be such
     that dim0 is divisible by 8 and dim1 is divisible by 16.
     """
-    return all(not t.shape[0] % 8 and not t.shape[1] % 16 for t in tensors)
+    return not tensor.shape[0] % 8 and not tensor.shape[1] % 16
+
+
+def assert_dim_for_fp8_forward_exec(tensor: torch.Tensor) -> None:
+    """For fp8 fprop (TN layout), inputs and weights must be such
+    that dim0 is divisible by 8 and dim1 is divisible by 16.
+    """
+    # single tensor check so it's clear which tensor is triggering the assertion
+    assert check_dim_for_fp8_forward_exec(tensor), (
+        "Tensor dimensions are not compatible for FP8 execution: "
+        f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)"
+    )
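As a rough usage sketch: asserting one tensor at a time makes it obvious whether the input or a particular weight violates the FP8 TN-layout divisibility rule. The helpers below are copied from the diff above so the snippet runs without Transformer Engine installed; the example shapes are arbitrary:

import torch

# Standalone copies of the two helpers added in this commit.
def check_dim_for_fp8_forward_exec(tensor: torch.Tensor) -> bool:
    """dim0 must be divisible by 8 and dim1 by 16 for FP8 fprop (TN layout)."""
    return not tensor.shape[0] % 8 and not tensor.shape[1] % 16

def assert_dim_for_fp8_forward_exec(tensor: torch.Tensor) -> None:
    assert check_dim_for_fp8_forward_exec(tensor), (
        "Tensor dimensions are not compatible for FP8 execution: "
        f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)"
    )

assert_dim_for_fp8_forward_exec(torch.empty(16, 32))  # passes: 16 % 8 == 0 and 32 % 16 == 0
assert_dim_for_fp8_forward_exec(torch.empty(10, 32))  # raises; the message lists both dims: (10 % 8 != 0, 32 % 16 != 0)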