Unverified Commit cd54a8cd authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Misc fixes for release_v1.6 (#784)



* fixes; docs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Check for FP8
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix LoRA-like use cases
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reviews
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7c1828f8
...@@ -28,6 +28,7 @@ from ..utils import ( ...@@ -28,6 +28,7 @@ from ..utils import (
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
clear_tensor_data, clear_tensor_data,
requires_grad,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -355,7 +356,11 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -355,7 +356,11 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8 ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
ctx.reduce_and_update_bwd_fp8_tensors = (
ctx.reduce_and_update_bwd_fp8_tensors or
FP8GlobalStateManager.is_first_fp8_module())
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and sequence_parallel: if parallel_mode == "row" and sequence_parallel:
...@@ -699,7 +704,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -699,7 +704,7 @@ class _LayerNormLinear(torch.autograd.Function):
else: else:
wgrad = None wgrad = None
if ctx.is_first_module and not is_graph_capturing(): if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
......
...@@ -33,6 +33,7 @@ from ..utils import ( ...@@ -33,6 +33,7 @@ from ..utils import (
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
clear_tensor_data, clear_tensor_data,
requires_grad,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -544,7 +545,10 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -544,7 +545,10 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
ctx.primary_weights_in_fp8 = primary_weights_in_fp8 ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(
inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
...@@ -1121,7 +1125,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1121,7 +1125,7 @@ class _LayerNormMLP(torch.autograd.Function):
else: else:
fc2_wgrad = None fc2_wgrad = None
if ctx.is_first_module and not is_graph_capturing(): if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
......
...@@ -26,6 +26,7 @@ from ..utils import ( ...@@ -26,6 +26,7 @@ from ..utils import (
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
clear_tensor_data, clear_tensor_data,
init_method_constant, init_method_constant,
requires_grad,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -363,7 +364,11 @@ class _Linear(torch.autograd.Function): ...@@ -363,7 +364,11 @@ class _Linear(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.is_input_fp8 = is_input_fp8 ctx.is_input_fp8 = is_input_fp8
ctx.primary_weights_in_fp8 = primary_weights_in_fp8 ctx.primary_weights_in_fp8 = primary_weights_in_fp8
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weight, bias):
ctx.reduce_and_update_bwd_fp8_tensors = (
ctx.reduce_and_update_bwd_fp8_tensors or
FP8GlobalStateManager.is_first_fp8_module())
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
...@@ -381,7 +386,7 @@ class _Linear(torch.autograd.Function): ...@@ -381,7 +386,7 @@ class _Linear(torch.autograd.Function):
def backward( def backward(
ctx, grad_output: torch.Tensor ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad_output[0], Float8Tensor): if isinstance(grad_output, Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[ ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_output._scale_inv tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_output._scale_inv
...@@ -611,7 +616,7 @@ class _Linear(torch.autograd.Function): ...@@ -611,7 +616,7 @@ class _Linear(torch.autograd.Function):
else: else:
wgrad = None wgrad = None
if ctx.is_first_module and not is_graph_capturing(): if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return ( return (
...@@ -954,8 +959,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -954,8 +959,6 @@ class Linear(TransformerEngineBaseModule):
* it also allows skipping gradient accumulation during the * it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
is_first_module_in_mha: Optional[bool], default = False
Whether to output in FP8. By default, Linear outputs in inp.dtype.
""" """
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
......
...@@ -8,6 +8,14 @@ from typing import Any, Callable, Optional, Tuple ...@@ -8,6 +8,14 @@ from typing import Any, Callable, Optional, Tuple
import torch import torch
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> bool:
    """Return True if any of the given tensors requires gradient.

    Parameters
    ----------
    *tensors : Optional[torch.Tensor]
        Tensors to inspect. ``None`` entries are permitted and ignored,
        so callers may pass optional weights/biases directly.

    Returns
    -------
    bool
        True if at least one non-None tensor has ``requires_grad`` set;
        False otherwise (including when called with no tensors).
    """
    # Fixed return annotation: the original declared `-> None` while
    # clearly returning a bool.
    return any(t is not None and t.requires_grad for t in tensors)
def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
""" """
Trick to deallocate tensor memory when delete operation does not Trick to deallocate tensor memory when delete operation does not
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment