Unverified Commit 2f25d121 authored by Xin Yao, committed by GitHub

[PyTorch] Fix setting `align_size` when FP8 is not initialized (#1926)



* Fix align_size
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update docstring
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 9d031fbd
@@ -74,11 +74,12 @@ class Fp8Padding(torch.nn.Module):
     Parameters
     ----------
-    num_gemms: int
-        number of GEMMs to be performed simutaneously.
-    align_size: int, optional
+    num_gemms : int
+        number of GEMMs to be performed simultaneously.
+    align_size : int, optional
         the alignment size for the input tensor. If not provided, the alignment size will
-        be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
+        be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
+        forward pass.
     """

     def __init__(
@@ -89,9 +90,6 @@ class Fp8Padding(torch.nn.Module):
         super().__init__()
         self.num_gemms = num_gemms
-        if align_size is None:
-            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
-        else:
-            self.align_size = align_size
+        self.align_size = align_size

     @no_torch_dynamo()
@@ -112,6 +110,8 @@ class Fp8Padding(torch.nn.Module):
         """
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
+        if self.align_size is None:
+            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16

         # FP8 padding calculate
         padded_m_splits = [
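For context, the padding this module performs rounds each entry of `m_splits` up to a multiple of the resolved `align_size`. A minimal sketch of that rounding, independent of Transformer Engine and assuming the non-MXFP8 default of 16, might look like:

# Hypothetical illustration of the alignment rounding implied by `padded_m_splits`
# (not the library's code); assumes align_size resolves to 16 for a non-MXFP8 recipe.
align_size = 16
m_splits = [100, 60, 30]
padded_m_splits = [(m + align_size - 1) // align_size * align_size for m in m_splits]
print(padded_m_splits)  # [112, 64, 32]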
@@ -72,11 +72,12 @@ class Fp8Unpadding(torch.nn.Module):
     Parameters
     ----------
-    num_gemms: int
-        number of GEMMs to be performed simutaneously.
-    align_size: int, optional
+    num_gemms : int
+        number of GEMMs to be performed simultaneously.
+    align_size : int, optional
         the alignment size for the input tensor. If not provided, the alignment size will
-        be determined by the FP8 recipe, 32 for MXFP8 and 16 for others.
+        be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
+        forward pass.
     """

     def __init__(
@@ -87,9 +88,6 @@ class Fp8Unpadding(torch.nn.Module):
         super().__init__()
         self.num_gemms = num_gemms
-        if align_size is None:
-            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
-        else:
-            self.align_size = align_size
+        self.align_size = align_size

     @no_torch_dynamo()
@@ -110,6 +108,8 @@ class Fp8Unpadding(torch.nn.Module):
         """
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
+        if self.align_size is None:
+            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16

         # FP8 padding calculate
         padded_m_splits = [
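The same deferral applies to `Fp8Unpadding`. The pattern in both classes is to keep the user-supplied `align_size` (possibly `None`) at construction time and only consult the FP8 recipe inside the first forward pass, when FP8 state has been initialized. A standalone sketch of that resolution logic, with a hypothetical `recipe_is_mxfp8` callable standing in for `FP8GlobalStateManager.get_fp8_recipe().mxfp8()`:

from typing import Callable, Optional

def resolve_align_size(align_size: Optional[int], recipe_is_mxfp8: Callable[[], bool]) -> int:
    # Hypothetical helper (not TE code): honor an explicit align_size, otherwise
    # derive the default from the recipe only at the moment it is queried.
    if align_size is not None:
        return align_size
    return 32 if recipe_is_mxfp8() else 16

# Resolved lazily, e.g. on the first forward pass:
print(resolve_align_size(None, lambda: True))    # 32 (MXFP8)
print(resolve_align_size(None, lambda: False))   # 16 (other recipes)
print(resolve_align_size(64, lambda: True))      # 64 (explicit value wins)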