Unverified Commit 2f25d121 authored by Xin Yao, Committed by GitHub
Browse files

[PyTorch] Fix setting `align_size` when FP8 is not initialized (#1926)



* Fix align_size
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update docstring
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 9d031fbd
...@@ -74,11 +74,12 @@ class Fp8Padding(torch.nn.Module): ...@@ -74,11 +74,12 @@ class Fp8Padding(torch.nn.Module):
Parameters Parameters
---------- ----------
num_gemms: int num_gemms : int
number of GEMMs to be performed simutaneously. number of GEMMs to be performed simultaneously.
align_size: int, optional align_size : int, optional
the alignment size for the input tensor. If not provided, the alignment size will the alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
forward pass.
""" """
def __init__( def __init__(
...@@ -89,9 +90,6 @@ class Fp8Padding(torch.nn.Module): ...@@ -89,9 +90,6 @@ class Fp8Padding(torch.nn.Module):
super().__init__() super().__init__()
self.num_gemms = num_gemms self.num_gemms = num_gemms
if align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
else:
self.align_size = align_size self.align_size = align_size
@no_torch_dynamo() @no_torch_dynamo()
...@@ -112,6 +110,8 @@ class Fp8Padding(torch.nn.Module): ...@@ -112,6 +110,8 @@ class Fp8Padding(torch.nn.Module):
""" """
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
......
...@@ -72,11 +72,12 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -72,11 +72,12 @@ class Fp8Unpadding(torch.nn.Module):
Parameters Parameters
---------- ----------
num_gemms: int num_gemms : int
number of GEMMs to be performed simutaneously. number of GEMMs to be performed simultaneously.
align_size: int, optional align_size : int, optional
the alignment size for the input tensor. If not provided, the alignment size will the alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe, 32 for MXFP8 and 16 for others. be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
forward pass.
""" """
def __init__( def __init__(
...@@ -87,9 +88,6 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -87,9 +88,6 @@ class Fp8Unpadding(torch.nn.Module):
super().__init__() super().__init__()
self.num_gemms = num_gemms self.num_gemms = num_gemms
if align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
else:
self.align_size = align_size self.align_size = align_size
@no_torch_dynamo() @no_torch_dynamo()
...@@ -110,6 +108,8 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -110,6 +108,8 @@ class Fp8Unpadding(torch.nn.Module):
""" """
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment