Unverified commit 80825fde authored by Sudhakar Singh, committed by GitHub

Don't save fp8 weight tensors if `is_first_microbatch` is None (#244)



* extend fp8 weight placeholders logic for Linear, LNLinear, LNMLP
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_mlp.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update linear.py
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update layernorm_linear.py
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Update layernorm_mlp.py
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5495883c
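For context, a minimal sketch of the weight-caching contract this change optimizes. The `is_first_microbatch` argument and `te.Linear`/`te.fp8_autocast` are the documented Transformer Engine API; the layer sizes and data below are illustrative, and the sketch assumes transformer_engine.pytorch is installed and an fp8-capable GPU is available:

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)
microbatches = [torch.randn(32, 1024, device="cuda") for _ in range(4)]

with te.fp8_autocast(enabled=True):
    # Passing a bool enables fp8 weight caching: the fp8 weights are computed
    # on the first microbatch and reused afterwards, so persistent fp8 weight
    # buffers on the module pay off.
    for i, mb in enumerate(microbatches):
        out = layer(mb, is_first_microbatch=(i == 0))

    # With is_first_microbatch=None (the default) there is no reuse across
    # calls; after this change the fp8 weight tensors are allocated per call
    # instead of being kept alive on the module.
    out = layer(microbatches[0])
```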
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -561,7 +561,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.set_activation_dtype(inp)
         self.fp8_init(num_gemms=num_gemms)
-        self.set_fp8_weights()
+
+        # Create persistent tensors for fp8 weights and their transposes
+        # only when fp8 weight caching is used.
+        if is_first_microbatch is not None:
+            self.set_fp8_weights()
 
         update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
 
         if self.fp8 and self.sequence_parallel:
@@ -765,6 +769,50 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
 
+    def get_fp8_weights_empty_tensors(
+        self,
+        is_first_microbatch: Union[bool, None],
+    ) -> List[torch.Tensor]:
+        """
+        Returns empty tensors that will later hold the fp8 weights and their
+        transposes (for the bwd pass) for this batch (or microbatch). This is
+        especially useful when `is_first_microbatch` is `None`, since the fp8
+        weights are then needed only once, in the forward pass, and don't have
+        to be stored persistently on the module. Note that the fp8 weight
+        transpose is still needed in the backward pass, but that is taken care
+        of by saving the transpose tensor via `ctx.save_for_backward`.
+        """
+        assert is_first_microbatch is None, (
+            "Should only be here when `is_first_microbatch` is None!"
+        )
+        fp8_weight_tensors = []
+        for shape in self.fp8_weight_shapes:
+            # fp8 weight buffer for the forward GEMM.
+            fp8_weight_tensors.append(
+                torch.empty(
+                    shape,
+                    device=torch.cuda.current_device(),
+                    dtype=torch.uint8,
+                )
+            )
+            # Transposed fp8 weight buffer, used in the backward pass.
+            fp8_weight_tensors.append(
+                torch.empty(
+                    shape[1],
+                    shape[0],
+                    device=torch.cuda.current_device(),
+                    dtype=torch.uint8,
+                )
+            )
+        return fp8_weight_tensors
+
     @abstractmethod
     def forward(self):
         """Needs override."""
+
+    @abstractmethod
+    def get_fp8_weights_scratchpad(
+        self,
+        is_first_microbatch: Union[bool, None],
+    ) -> List[torch.Tensor]:
+        """Needs override."""
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -4,7 +4,7 @@
 """LayerNormLinear API"""
 import os
-from typing import Union, Optional, Callable, Tuple, Dict, Any
+from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 
 import torch
@@ -791,6 +791,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
             init.zeros_(self.layer_norm_weight)
         init.zeros_(self.layer_norm_bias)
 
+    def get_fp8_weights_scratchpad(
+        self,
+        is_first_microbatch: Union[bool, None],
+    ) -> List[torch.Tensor]:
+        """
+        Fetch the fp8 weight tensor placeholders if they exist (when
+        `is_first_microbatch` is not `None`) or return empty fp8 weight
+        tensors (when `is_first_microbatch` is `None`).
+        """
+        if not self.fp8:
+            return [None, None]
+
+        if is_first_microbatch is None:
+            # Return empty weight placeholders for each fwd/bwd pass
+            fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
+                is_first_microbatch
+            )
+        else:
+            # These persistent weight placeholders should have been created
+            # in the `set_fp8_weights` method
+            fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]
+        return fp8_weight_tensors
+
     def forward(
         self,
         inp: torch.Tensor,
@@ -841,6 +865,11 @@
                 else self.noop_cat("weight_tensor", self.weight_names)
             )
 
+            # Fetch the fp8 weights placeholders (for linear/gemm)
+            weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
+                is_first_microbatch
+            )
+
             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormLinear.apply
                 args = []
@@ -852,8 +881,8 @@
                 self.layer_norm_weight,
                 self.layer_norm_bias,
                 weight_tensor,
-                self.weight1_fp8 if self.fp8 else None,
-                self.weight1_t_fp8 if self.fp8 else None,
+                weight1_fp8,
+                weight1_t_fp8,
                 bias_tensor,
                 self.apply_bias and not self.gemm_bias_unfused_add,
                 self.eps,
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -4,7 +4,7 @@
 """LayerNormMLP API"""
 import os
-from typing import Union, Optional, Callable, Tuple, Dict, Any
+from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 
 import torch
 from torch.nn.parameter import Parameter
@@ -1063,6 +1063,31 @@ class LayerNormMLP(TransformerEngineBaseModule):
             init.zeros_(self.layer_norm_weight)
         init.zeros_(self.layer_norm_bias)
 
+    def get_fp8_weights_scratchpad(
+        self,
+        is_first_microbatch: Union[bool, None],
+    ) -> List[torch.Tensor]:
+        """
+        Fetch the fp8 weight tensor placeholders if they exist (when
+        `is_first_microbatch` is not `None`) or return empty fp8 weight
+        tensors (when `is_first_microbatch` is `None`).
+        """
+        if not self.fp8:
+            return [None, None, None, None]
+
+        if is_first_microbatch is None:
+            # Return empty weight placeholders for each fwd/bwd pass
+            fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
+                is_first_microbatch
+            )
+        else:
+            # These persistent weight placeholders should have been created
+            # in the `set_fp8_weights` method
+            fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8,
+                                  self.weight2_fp8, self.weight2_t_fp8]
+        return fp8_weight_tensors
+
     def forward(
         self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
@@ -1089,6 +1114,12 @@
         """
         with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
+            # Fetch the fp8 weights placeholders (for linear/gemm)
+            weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \
+                self.get_fp8_weights_scratchpad(
+                    is_first_microbatch
+                )
+
             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormMLP.apply
                 args = []
@@ -1100,13 +1131,13 @@
                 self.layer_norm_weight,
                 self.layer_norm_bias,
                 self.fc1_weight,
-                self.weight1_fp8 if self.fp8 else None,
-                self.weight1_t_fp8 if self.fp8 else None,
+                weight1_fp8,
+                weight1_t_fp8,
                 self.fc1_bias,
                 self.use_bias,
                 self.fc2_weight,
-                self.weight2_fp8 if self.fp8 else None,
-                self.weight2_t_fp8 if self.fp8 else None,
+                weight2_fp8,
+                weight2_t_fp8,
                 self.fc2_bias,
                 self.apply_bias and not self.gemm_bias_unfused_add,
                 self.eps,
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -3,7 +3,7 @@
 # See LICENSE for license information.
 
 """Linear API"""
-from typing import Union, Optional, Callable, Tuple, Dict, Any
+from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 
 import torch
 from torch.nn.parameter import Parameter
@@ -641,6 +641,30 @@ class Linear(TransformerEngineBaseModule):
         else:
             self.gemm_bias_unfused_add = False
 
+    def get_fp8_weights_scratchpad(
+        self,
+        is_first_microbatch: Union[bool, None],
+    ) -> List[torch.Tensor]:
+        """
+        Fetch the fp8 weight tensor placeholders if they exist (when
+        `is_first_microbatch` is not `None`) or return empty fp8 weight
+        tensors (when `is_first_microbatch` is `None`).
+        """
+        if not self.fp8:
+            return [None, None]
+
+        if is_first_microbatch is None:
+            # Return empty weight placeholders for each fwd/bwd pass
+            fp8_weight_tensors = self.get_fp8_weights_empty_tensors(
+                is_first_microbatch
+            )
+        else:
+            # These persistent weight placeholders should have been created
+            # in the `set_fp8_weights` method
+            fp8_weight_tensors = [self.weight1_fp8, self.weight1_t_fp8]
+        return fp8_weight_tensors
+
     def forward(
         self,
         inp: torch.Tensor,
@@ -691,6 +715,11 @@
                 else self.noop_cat("weight_tensor", self.weight_names)
             )
 
+            # Fetch the fp8 weights placeholders (for linear/gemm)
+            weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
+                is_first_microbatch
+            )
+
             if torch.is_grad_enabled():
                 linear_fn = _Linear.apply
                 args = []
@@ -699,8 +728,8 @@
                 args = [None]
             args += (
                 weight_tensor,
-                self.weight1_fp8 if self.fp8 else None,
-                self.weight1_t_fp8 if self.fp8 else None,
+                weight1_fp8,
+                weight1_t_fp8,
                 inp,
                 bias_tensor,
                 self.apply_bias and not self.gemm_bias_unfused_add,
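To make the shared branch logic above concrete, here is a toy, self-contained mock (hypothetical class name, simplified to one GEMM; not Transformer Engine code) showing the behavior the three `get_fp8_weights_scratchpad` overrides implement: persistent buffers are reused when microbatch caching is on, and fresh buffers are allocated per call when `is_first_microbatch` is `None`:

```python
from typing import List, Optional
import torch

class ScratchpadSketch:
    """Mock of the scratchpad dispatch for a single-GEMM module."""

    def __init__(self, shape):
        self.fp8 = True
        # Mimics set_fp8_weights(): persistent placeholders on the module.
        self.weight1_fp8 = torch.empty(shape, dtype=torch.uint8)
        self.weight1_t_fp8 = torch.empty(shape[1], shape[0], dtype=torch.uint8)

    def get_fp8_weights_scratchpad(
        self, is_first_microbatch: Optional[bool]
    ) -> List[torch.Tensor]:
        if is_first_microbatch is None:
            # No caching: allocate throwaway buffers for this call only.
            return [torch.empty_like(self.weight1_fp8),
                    torch.empty_like(self.weight1_t_fp8)]
        # Caching: hand back the persistent module buffers.
        return [self.weight1_fp8, self.weight1_t_fp8]

m = ScratchpadSketch((8, 4))
fresh = m.get_fp8_weights_scratchpad(None)
cached = m.get_fp8_weights_scratchpad(True)
assert fresh[0].data_ptr() != m.weight1_fp8.data_ptr()  # fresh each call
assert cached[0] is m.weight1_fp8                       # reused buffer
```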