"superbench/vscode:/vscode.git/clone" did not exist on "949f9cb406a0263e45c38825b6953f3b46953c9e"
Unverified Commit 80825fde authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

Don't save fp8 weight tensors if `is_first_microbatch` is None (#244)



* extend fp8 weight placeholders logic for Linear, LNLinear, LNMLP
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_mlp.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update linear.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update layernorm_linear.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* Update layernorm_mlp.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5495883c
...@@ -561,7 +561,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -561,7 +561,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_activation_dtype(inp) self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms) self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
# Create persistent tensors for fp8 weights and their transposes
# only when fp8 weight caching is used.
if is_first_microbatch is not None:
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if self.fp8 and self.sequence_parallel: if self.fp8 and self.sequence_parallel:
...@@ -765,6 +769,50 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -765,6 +769,50 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames]) return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
def get_fp8_weights_empty_tensors(
    self,
    is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
    """
    Allocate fresh, uninitialized uint8 tensors to hold the fp8 versions of
    the module's weights and their transposes (the transposes are needed in
    the backward pass) for the current batch/microbatch.

    Only valid when `is_first_microbatch` is `None`, i.e. when fp8 weight
    caching is disabled: in that case the fp8 weights are needed just once
    in the forward pass, so no persistent placeholders are kept. The
    transpose tensor still survives until backward because the autograd
    function stores it via `ctx.save_for_backward`.
    """
    assert is_first_microbatch is None, "Should only be here when "\
        "`is_first_microbatch` is None!"
    # Hoisted out of the loop: the current device cannot change mid-call.
    device = torch.cuda.current_device()
    placeholders = []
    for shape in self.fp8_weight_shapes:
        # One tensor for the fp8 weight, one for its transpose.
        placeholders.extend((
            torch.empty(shape, device=device, dtype=torch.uint8),
            torch.empty(shape[1], shape[0], device=device, dtype=torch.uint8),
        ))
    return placeholders
@abstractmethod @abstractmethod
def forward(self): def forward(self):
"""Needs override.""" """Needs override."""
@abstractmethod
def get_fp8_weights_scratchpad(
    self,
    is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
    """Needs override.

    Subclasses return the fp8 weight/weight-transpose placeholder tensors
    for this iteration (persistent ones when `is_first_microbatch` is a
    bool, freshly allocated ones when it is `None`).
    """
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""LayerNormLinear API""" """LayerNormLinear API"""
import os import os
from typing import Union, Optional, Callable, Tuple, Dict, Any from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch import torch
...@@ -791,6 +791,30 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -791,6 +791,30 @@ class LayerNormLinear(TransformerEngineBaseModule):
init.zeros_(self.layer_norm_weight) init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias) init.zeros_(self.layer_norm_bias)
def get_fp8_weights_scratchpad(
    self,
    is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
    """
    Fetch the fp8 weight tensor placeholders if they exist (when
    `is_first_microbatch` is not `None`) or return empty fp8 weight
    tensors (if `is_first_microbatch is None`)
    """
    # Without fp8 there is nothing to place-hold.
    if not self.fp8:
        return [None, None]
    # No weight caching: hand back throwaway tensors for this fwd/bwd pass.
    if is_first_microbatch is None:
        return self.get_fp8_weights_empty_tensors(is_first_microbatch)
    # Caching enabled: reuse the persistent placeholders that
    # `set_fp8_weights` created.
    return [self.weight1_fp8, self.weight1_t_fp8]
def forward( def forward(
self, self,
inp: torch.Tensor, inp: torch.Tensor,
...@@ -841,6 +865,11 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -841,6 +865,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
else self.noop_cat("weight_tensor", self.weight_names) else self.noop_cat("weight_tensor", self.weight_names)
) )
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
is_first_microbatch
)
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply fwd_fn = _LayerNormLinear.apply
args = [] args = []
...@@ -852,8 +881,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -852,8 +881,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.layer_norm_weight, self.layer_norm_weight,
self.layer_norm_bias, self.layer_norm_bias,
weight_tensor, weight_tensor,
self.weight1_fp8 if self.fp8 else None, weight1_fp8,
self.weight1_t_fp8 if self.fp8 else None, weight1_t_fp8,
bias_tensor, bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""LayerNormMLP API""" """LayerNormMLP API"""
import os import os
from typing import Union, Optional, Callable, Tuple, Dict, Any from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -1063,6 +1063,31 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1063,6 +1063,31 @@ class LayerNormMLP(TransformerEngineBaseModule):
init.zeros_(self.layer_norm_weight) init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias) init.zeros_(self.layer_norm_bias)
def get_fp8_weights_scratchpad(
    self,
    is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
    """
    Fetch the fp8 weight tensor placeholders if they exist (when
    `is_first_microbatch` is not `None`) or return empty fp8 weight
    tensors (if `is_first_microbatch is None`)
    """
    # Two GEMMs (fc1, fc2) => four placeholders (weight + transpose each).
    if not self.fp8:
        return [None, None, None, None]
    # No weight caching: hand back throwaway tensors for this fwd/bwd pass.
    if is_first_microbatch is None:
        return self.get_fp8_weights_empty_tensors(is_first_microbatch)
    # Caching enabled: reuse the persistent placeholders that
    # `set_fp8_weights` created.
    return [
        self.weight1_fp8,
        self.weight1_t_fp8,
        self.weight2_fp8,
        self.weight2_t_fp8,
    ]
def forward( def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
...@@ -1089,6 +1114,12 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1089,6 +1114,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
""" """
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \
self.get_fp8_weights_scratchpad(
is_first_microbatch
)
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormMLP.apply fwd_fn = _LayerNormMLP.apply
args = [] args = []
...@@ -1100,13 +1131,13 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1100,13 +1131,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_weight, self.layer_norm_weight,
self.layer_norm_bias, self.layer_norm_bias,
self.fc1_weight, self.fc1_weight,
self.weight1_fp8 if self.fp8 else None, weight1_fp8,
self.weight1_t_fp8 if self.fp8 else None, weight1_t_fp8,
self.fc1_bias, self.fc1_bias,
self.use_bias, self.use_bias,
self.fc2_weight, self.fc2_weight,
self.weight2_fp8 if self.fp8 else None, weight2_fp8,
self.weight2_t_fp8 if self.fp8 else None, weight2_t_fp8,
self.fc2_bias, self.fc2_bias,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Linear API""" """Linear API"""
from typing import Union, Optional, Callable, Tuple, Dict, Any from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -641,6 +641,30 @@ class Linear(TransformerEngineBaseModule): ...@@ -641,6 +641,30 @@ class Linear(TransformerEngineBaseModule):
else: else:
self.gemm_bias_unfused_add = False self.gemm_bias_unfused_add = False
def get_fp8_weights_scratchpad(
    self,
    is_first_microbatch: Union[bool, None],
) -> List[torch.Tensor]:
    """
    Fetch the fp8 weight tensor placeholders if they exist (when
    `is_first_microbatch` is not `None`) or return empty fp8 weight
    tensors (if `is_first_microbatch is None`)
    """
    # Without fp8 there is nothing to place-hold.
    if not self.fp8:
        return [None, None]
    # No weight caching: hand back throwaway tensors for this fwd/bwd pass.
    if is_first_microbatch is None:
        return self.get_fp8_weights_empty_tensors(is_first_microbatch)
    # Caching enabled: reuse the persistent placeholders that
    # `set_fp8_weights` created.
    return [self.weight1_fp8, self.weight1_t_fp8]
def forward( def forward(
self, self,
inp: torch.Tensor, inp: torch.Tensor,
...@@ -691,6 +715,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -691,6 +715,11 @@ class Linear(TransformerEngineBaseModule):
else self.noop_cat("weight_tensor", self.weight_names) else self.noop_cat("weight_tensor", self.weight_names)
) )
# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
is_first_microbatch
)
if torch.is_grad_enabled(): if torch.is_grad_enabled():
linear_fn = _Linear.apply linear_fn = _Linear.apply
args = [] args = []
...@@ -699,8 +728,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -699,8 +728,8 @@ class Linear(TransformerEngineBaseModule):
args = [None] args = [None]
args += ( args += (
weight_tensor, weight_tensor,
self.weight1_fp8 if self.fp8 else None, weight1_fp8,
self.weight1_t_fp8 if self.fp8 else None, weight1_t_fp8,
inp, inp,
bias_tensor, bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment