Unverified Commit 45a2ac41 authored by Tian Zheng, committed by GitHub

[Paddle] Support recompute (#412)



* Add recompute
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Support recompute core attention
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix transformer layer recompute
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add doc
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Improve recompute test
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Improve performance of stack backtrace
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Improve code style
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix code style
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
parent a5dbf1e2
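
For orientation, a minimal usage sketch of the `te.recompute` API this commit adds (assuming an FP8-capable GPU; the layer sizes are illustrative, and `use_reentrant=False` selects Paddle's saved_tensors_hooks-based recompute path):

# Hedged sketch, not part of the commit: wrap a layer call in te.recompute
# so its forward activations are re-materialized during the backward pass.
import paddle
import transformer_engine.paddle as te

layer = te.TransformerLayer(1024, 4096, 16, layer_type='encoder')
inp = paddle.uniform([4, 128, 1024])
inp.stop_gradient = False
mask = paddle.zeros(shape=(4, 1, 128, 128), dtype='bool')

with te.fp8_autocast(enabled=True):
    out = te.recompute(layer, inp, mask, use_reentrant=False)
out.mean().backward()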
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TransformerLayer encoder recompute"""

import sys

import paddle
import transformer_engine.paddle as te


class Net(paddle.nn.Layer):
    """Network used for recompute testing"""

    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, inp, mask, enable_recompute, use_reentrant):
        # Thread the output of each layer into the next so the four encoder
        # layers are actually stacked (the original assigned to a temporary
        # and discarded every output but the last).
        for layer in self.layers:
            if enable_recompute:
                inp = te.recompute(layer, inp, mask, use_reentrant=use_reentrant)
            else:
                inp = layer(inp, mask)
        return inp


def main():
    """Main function"""
    paddle.seed(10)
    batch_size = 16
    hidden_size = 4096
    num_heads = 32
    ffn_hidden_size = 16384
    q_seqlen = 512
    kv_seqlen = 512
    num_layers = 4
    enable_recompute = int(sys.argv[1])
    use_reentrant = int(sys.argv[2])

    layers = paddle.nn.LayerList([
        te.TransformerLayer(
            hidden_size,
            ffn_hidden_size,
            num_heads,
            layer_type='encoder',
        ) for _ in range(num_layers)
    ])
    model = Net(layers)
    optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters())

    for _ in range(10):
        inp = paddle.uniform([batch_size, q_seqlen, hidden_size])
        inp.stop_gradient = False
        mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype='bool')
        with te.fp8_autocast(enabled=True):
            out = model(inp, mask, enable_recompute, use_reentrant)
        loss = out.mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

    print("Loss: ", float(loss))
    print("Peak memory: ", paddle.device.cuda.max_memory_allocated(0))


if __name__ == "__main__":
    main()
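
The script's CLI contract, as exercised by the pytest driver further down, is two integer flags: argv[1] toggles recompute and argv[2] toggles the reentrant implementation. A hedged sketch of an equivalent in-process launch:

# Sketch only: upstream invokes this file as
#   python recompute_transformer_encoder.py <enable_recompute> <use_reentrant>
import sys

sys.argv = ['recompute_transformer_encoder.py', '1', '0']    # recompute on, non-reentrant
main()    # prints "Loss: ..." and "Peak memory: ..."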
...@@ -14,12 +14,18 @@ import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast
from transformer_engine.common.recipe import DelayedScaling

-paddle.seed(10)
is_fp8_supported, reason = is_fp8_available()

LINEAR_CASES = [(16, 16, 32), (32, 32, 64)]
NORM_CASES = [(16, 32), (256, 1024)]


+@pytest.fixture(autouse=True)
+def setup():
+    """Setup random seed before each test"""
+    paddle.seed(10)
+    yield
+
+
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('use_fp8', [True, False])
def test_checkpoint(use_fp8):
...@@ -897,9 +903,11 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
+@pytest.mark.parametrize('recompute_core_attention', [True, False])
-def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                   no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                   output_layernorm, return_layernorm_output):
+def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
+                                   no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
+                                   output_layernorm, return_layernorm_output,
+                                   recompute_core_attention):
    """
    Test Transformer Decoder Layer
    """
...@@ -1049,18 +1057,33 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_te.layernorm.weight.stop_gradient = no_wgrad
    layer_te.layernorm.bias.stop_gradient = no_dbias

-    def calc_transformer_output_and_grad(layer, encoder_input, mask, encoder_output,
-                                         enc_dec_attn_mask, dout):
+    def calc_transformer_output_and_grad(layer,
+                                         encoder_input,
+                                         mask,
+                                         encoder_output,
+                                         enc_dec_attn_mask,
+                                         dout,
+                                         recompute_core_attention=False):
        _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
        _encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False)
-        out = layer(_encoder_input, mask, _encoder_output, enc_dec_attn_mask)
+        out = layer(_encoder_input,
+                    mask,
+                    _encoder_output,
+                    enc_dec_attn_mask,
+                    recompute_core_attention=recompute_core_attention)
        out.backward(dout)
        return out, _encoder_input.grad, _encoder_output.grad

    out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad(
        layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out)
    out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad(
-        layer_te, encoder_input, attn_mask, encoder_output, attn_mask, grad_out)
+        layer_te,
+        encoder_input,
+        attn_mask,
+        encoder_output,
+        attn_mask,
+        grad_out,
+        recompute_core_attention=recompute_core_attention)

    assert_allclose(out, out_ref, rtol=rtol, atol=atol)
    assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol)
......
...@@ -45,8 +45,6 @@ from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling

-np.random.seed(10)
-paddle.seed(11)
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
              (16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
...@@ -57,6 +55,14 @@ FLASH_ATTN_CASES = [(4, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]


+@pytest.fixture(autouse=True)
+def setup():
+    """Setup random seed before each test"""
+    np.random.seed(10)
+    paddle.seed(11)
+    yield
+
+
def test_quantize_dequantize():
    """
    Test cast_to_fp8 and cast_from_fp8
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE Paddle Recompute"""

from pathlib import Path
import re
import subprocess

import numpy as np
import pytest

from transformer_engine.paddle.fp8 import is_fp8_available

test_root = Path(__file__).resolve().parent
is_fp8_supported, reason = is_fp8_available()


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('use_reentrant', [False, True])
def test_transformer_encoder_recompute(use_reentrant):
    """
    Test TransformerLayer encoder recompute
    """
    rtol = 1e-5
    atol = 1e-5

    def launch_subprocess_and_check_output(enable_recompute):
        """Launch training in a subprocess and parse its output"""
        try:
            cmd = [
                'python',
                str(test_root / 'recompute_tests' / 'recompute_transformer_encoder.py'),
                str(int(enable_recompute)),
                str(int(use_reentrant))
            ]
            result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True)
            print(result)

            loss_match = re.search(r'Loss:\s+(-?\d+\.\d+)', result)
            memory_match = re.search(r'Peak memory:\s+(\d+)', result)
            loss_value = float(loss_match.group(1))
            memory_value = int(memory_match.group(1))
            return loss_value, memory_value
        except subprocess.CalledProcessError as e:
            raise ValueError(f"Subprocess failed with error: {e}") from e

    loss_recompute, peak_memory_recompute = launch_subprocess_and_check_output(True)
    loss_ref, peak_memory_ref = launch_subprocess_and_check_output(False)

    assert peak_memory_recompute < peak_memory_ref
    np.testing.assert_allclose(loss_recompute, loss_ref, rtol=rtol, atol=atol)
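
The two regexes pin down the exact stdout contract of the child script; a quick self-contained check of that parsing (the sample string is illustrative):

import re

sample = "Loss:  0.001234\nPeak memory:  123456789\n"
assert float(re.search(r'Loss:\s+(-?\d+\.\d+)', sample).group(1)) == 0.001234
assert int(re.search(r'Peak memory:\s+(\d+)', sample).group(1)) == 123456789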
...@@ -6,3 +6,4 @@
from .fp8 import fp8_autocast
from .layer import (Linear, LayerNorm, LayerNormLinear, LayerNormMLP, FusedScaleMaskSoftmax,
                    DotProductAttention, MultiHeadAttention, TransformerLayer)
+from .recompute import recompute
...@@ -50,3 +50,5 @@ LayerTypes = ("encoder", "decoder")
GemmParallelModes = ("row", "column", None)

dist_group_type = paddle.distributed.collective.Group
+
+RecomputeFunctionNames = ('unpack', 'backward')
...@@ -13,7 +13,7 @@ import transformer_engine_paddle as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
-from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer
+from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer

# FP8 support
_is_fp8_available = None
...@@ -59,8 +59,10 @@ class FP8State:
        self._is_first_fp8_module = False
        self._fp8_autocast_counter = 0
        self._fp8_autocast_depth = 0
+        self._fp8_recompute_enabled = False
        self._fp8_fwd_buffer = FP8MetaFwdBuffer()
        self._fp8_bwd_buffer = FP8MetaBwdBuffer()
+        self._fp8_recompute_buffer = FP8RecomputeBuffer()

    def is_fp8_enabled(self) -> bool:
        """Is FP8 enabled"""
...@@ -106,6 +108,14 @@ class FP8State:
        """Returns global fp8 backward buffer."""
        return self._fp8_bwd_buffer

+    def is_fp8_recompute_enabled(self) -> bool:
+        """Is FP8 recompute enabled"""
+        return self._fp8_recompute_enabled
+
+    def get_fp8_recompute_buffer(self) -> FP8RecomputeBuffer:
+        """Returns global fp8 recompute buffer."""
+        return self._fp8_recompute_buffer
+
    def enter(
        self,
        enabled: bool,
......
...@@ -4,6 +4,7 @@
"""FP8 meta buffer for FP8 amax reduction"""
from abc import ABC, abstractmethod
+from collections import deque
from functools import partial
import os
from typing import Dict, Any, List, Union
...@@ -11,7 +12,7 @@ from typing import Dict, Any, List, Union
import numpy as np
import paddle

-from .constants import dist_group_type
+from .constants import dist_group_type, RecomputeFunctionNames


class FP8MetaBufferBase(ABC):
...@@ -255,3 +256,60 @@ class FP8MetaBwdBuffer(FP8MetaBufferBase):
        """
        self._amax_reduce_wait_func = self._global_amax_reduction(fp8_meta, tp_group, tp_size)
        self._execute_deletion()
+
+
+class FP8RecomputeBuffer:
+    """Buffer used to hold FP8 meta tensors for recompute"""
+
+    def __init__(self):
+        self._data = []
+
+    @staticmethod
+    def get_buffer_position_key():
+        """Returns the key (in fp8_meta) for the recompute buffer position"""
+        return 'recompute_buffer_pos'
+
+    def stash_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None:
+        """Stash the scaling factors and amaxes for recompute"""
+        buffer_position_key = self.get_buffer_position_key()
+        to_copy = [
+            fp8_meta["scaling_fwd"].amax_history.clone(),
+            fp8_meta["scaling_fwd"].scale.clone(),
+            fp8_meta["scaling_fwd"].scale_inv.clone(),
+        ]
+        if buffer_position_key in fp8_meta:
+            self._data[fp8_meta[buffer_position_key]].append(to_copy)
+        else:
+            self._data.append(deque())
+            self._data[-1].append(to_copy)
+            fp8_meta[buffer_position_key] = len(self._data) - 1
+
+    def retrieve_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None:
+        """Switch to the previously saved scaling factors and amaxes"""
+        # Store the updated amaxes and scales from the phase-1 post-forward.
+        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
+        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
+        fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv
+
+        # Retrieve the amaxes and scales stashed before the phase-1 forward.
+        buffer_position_key = self.get_buffer_position_key()
+        stashed_fp8_meta = self._data[fp8_meta[buffer_position_key]].popleft()
+
+        # Replace amaxes and scales with the stashed values for the phase-2 forward.
+        fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
+        fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
+        fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
+
+    @staticmethod
+    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
+        """Restore the latest scaling factors and amaxes after the recompute forward run."""
+        assert "updated_amax_history_fwd" in fp8_meta, "Recompute internal error." \
+            " If you are not using recompute, please check whether" \
+            " the forward function is called from one of these functions: " \
+            f"{RecomputeFunctionNames}. If so, consider changing the function name " \
+            "or setting NVTE_DISABLE_RECOMPUTE=1."
+        fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
+        fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
+        fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
...@@ -23,6 +23,7 @@ from ..cpp_extensions import (
)
from ..distributed import get_tp_group_and_world_size, track_rng_state
from ..utils import attention_mask_func, divide, mask_to_cu_seqlens
+from ..recompute import recompute


class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
...@@ -383,6 +384,22 @@ class MultiHeadAttention(paddle.nn.Layer):
        whether to zero initialize the gamma of the layernorm operation.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
        backend to use for attention operation.
+
+    Parallelism parameters
+    ----------------------
+    set_parallel_mode : bool, default = `False`
+        if set to `True`, QKV and FC1 layers are used as Column Parallel
+        whereas PROJ and FC2 are used as Row Parallel as described
+        `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
+    tp_group : ProcessGroup, default = `None`
+        tensor parallel process group.
+    rng_state_name : str, default = `local_seed`
+        Controls the rng state used for dropout on attention probs. The
+        specified rng should be given different seeds on different TP ranks.
+        It will be ignored if `set_parallel_mode` is False. The specified
+        name should be registered through
+        `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
+        .add(rng_state_name, seed)`.
    """

    def __init__(
...@@ -516,6 +533,7 @@ class MultiHeadAttention(paddle.nn.Layer):
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
+        recompute_core_attention: bool = False,
    ) -> Tuple[Union[paddle.Tensor, None], ...]:
        """
        MultiHeadAttention Layer.
...@@ -535,7 +553,11 @@ class MultiHeadAttention(paddle.nn.Layer):
            Bias tensor for Q * K.T
        set_zero: bool, default = `True`
            Whether to use the fast path to set output tensors to 0 or not.
+        recompute_core_attention: bool, default = `False`
+            If true, forward activations for core attention are recomputed
+            during the backward pass in order to save the memory that would
+            otherwise be needed to store the forward activations until
+            backprop.
        """

        # hidden_states: [b, s_q, hidden_size]
...@@ -558,14 +580,26 @@ class MultiHeadAttention(paddle.nn.Layer):
            ])

            with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
-                context_layer = self.core_attention(
-                    query_layer=mixed_qkv_layer,
-                    key_value_layer=None,
-                    attention_mask=attention_mask,
-                    core_attention_bias_type=core_attention_bias_type,
-                    core_attention_bias=core_attention_bias,
-                    set_zero=set_zero,
-                )
+                if recompute_core_attention:
+                    context_layer = recompute(
+                        self.core_attention,
+                        mixed_qkv_layer,
+                        None,
+                        attention_mask,
+                        core_attention_bias_type,
+                        core_attention_bias,
+                        set_zero,
+                        use_reentrant=False,
+                    )
+                else:
+                    context_layer = self.core_attention(
+                        query_layer=mixed_qkv_layer,
+                        key_value_layer=None,
+                        attention_mask=attention_mask,
+                        core_attention_bias_type=core_attention_bias_type,
+                        core_attention_bias=core_attention_bias,
+                        set_zero=set_zero,
+                    )

        else:    # cross attention
            mixed_kv_layer = self.key_value(encoder_output)
...@@ -587,14 +621,26 @@ class MultiHeadAttention(paddle.nn.Layer):
                0, 0, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
            ])

            with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
-                context_layer = self.core_attention(
-                    query_layer=query_layer,
-                    key_value_layer=mixed_kv_layer,
-                    attention_mask=attention_mask,
-                    core_attention_bias_type=core_attention_bias_type,
-                    core_attention_bias=core_attention_bias,
-                    set_zero=set_zero,
-                )
+                if recompute_core_attention:
+                    context_layer = recompute(
+                        self.core_attention,
+                        query_layer,
+                        mixed_kv_layer,
+                        attention_mask,
+                        core_attention_bias_type,
+                        core_attention_bias,
+                        set_zero,
+                        use_reentrant=False,
+                    )
+                else:
+                    context_layer = self.core_attention(
+                        query_layer=query_layer,
+                        key_value_layer=mixed_kv_layer,
+                        attention_mask=attention_mask,
+                        core_attention_bias_type=core_attention_bias_type,
+                        core_attention_bias=core_attention_bias,
+                        set_zero=set_zero,
+                    )

        context_layer = paddle.reshape(context_layer,
                                       [0, 0, context_layer.shape[2] * context_layer.shape[3]])
......
...@@ -25,6 +25,8 @@ from ..fp8 import (
    get_fp8_te_dtype,
)
from ..profile import nvtx_range
+from ..recompute import is_in_recompute_phase
+from ..fp8_buffer import FP8RecomputeBuffer

_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
...@@ -199,6 +201,9 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[
            0]
+        recompute_buffer_pos_key = FP8RecomputeBuffer.get_buffer_position_key()
+        if recompute_buffer_pos_key in self.fp8_meta:
+            del self.fp8_meta[recompute_buffer_pos_key]

    @paddle.no_grad()
    def set_state_dict(self, state_dict, use_structured_name=True):
...@@ -221,34 +226,48 @@ class TransformerEngineBaseLayer(paddle.nn.Layer, ABC):
        just in case. The autocast exit will pick up the most recent one.
        """

-        self.set_activation_dtype(inp)
-        self.fp8_init(num_gemms=num_gemms)
-
-        # Previous iteration was grad_enabled
-        if self.fp8_meta.get("update_amax_and_scale_fwd", False):
-            global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer()
-            global_fp8_fwd_buffer.wait()
-            if self.fp8_meta["recipe"].reduce_amax:
-                global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
-                amax_and_scale_update(self.fp8_meta, True)
-                global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta)
-            else:
-                amax_and_scale_update(self.fp8_meta, True)
-
-        if self.fp8_enabled and self.training:
-            # Setup for amax reduction
-            if self.fp8_meta["recipe"].reduce_amax:
-                global_fp8_state = get_global_fp8_state()
-                self.fp8_meta["first_module"] = global_fp8_state.is_first_fp8_module()
-                self.fp8_meta["autocast_id_fwd"] = global_fp8_state.get_autocast_id()
-                self.fp8_meta["autocast_id_fwd_stack"].append(self.fp8_meta["autocast_id_fwd"])
-            self.fp8_meta["update_amax_and_scale_fwd"] = True
-        else:
-            self.fp8_meta["update_amax_and_scale_fwd"] = False
+        if self.fp8_enabled and is_in_recompute_phase():
+            global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer()
+            global_recompute_buffer.retrieve_fp8_meta_tensors(self.fp8_meta)
+        else:
+            self.set_activation_dtype(inp)
+            self.fp8_init(num_gemms=num_gemms)
+
+            # Previous iteration was grad_enabled
+            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
+                global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer()
+                global_fp8_fwd_buffer.wait()
+                if self.fp8_meta["recipe"].reduce_amax:
+                    global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta)
+                    amax_and_scale_update(self.fp8_meta, True)
+                    global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta)
+                else:
+                    amax_and_scale_update(self.fp8_meta, True)
+
+            if self.fp8_enabled and self.training:
+                # Setup for amax reduction
+                if self.fp8_meta["recipe"].reduce_amax:
+                    global_fp8_state = get_global_fp8_state()
+                    self.fp8_meta["first_module"] = global_fp8_state.is_first_fp8_module()
+                    self.fp8_meta["autocast_id_fwd"] = global_fp8_state.get_autocast_id()
+                    self.fp8_meta["autocast_id_fwd_stack"].append(self.fp8_meta["autocast_id_fwd"])
+                self.fp8_meta["update_amax_and_scale_fwd"] = True
+            else:
+                self.fp8_meta["update_amax_and_scale_fwd"] = False
+
+            # Activation recomputation is used and this is the first forward phase.
+            if (self.fp8_enabled and self.training
+                    and get_global_fp8_state().is_fp8_recompute_enabled()):
+                global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer()
+                global_recompute_buffer.stash_fp8_meta_tensors(self.fp8_meta)

        with nvtx_range(self.__class__.__name__ + " forward"):
            yield inp

+        if self.fp8_enabled and is_in_recompute_phase():
+            FP8RecomputeBuffer.restore_fp8_meta_tensors(self.fp8_meta)
+            return
+
        if self.fp8_enabled and self.training and self.fp8_meta["recipe"].reduce_amax:
            global_fp8_state = get_global_fp8_state()
            global_fp8_fwd_buffer = global_fp8_state.get_fp8_fwd_buffer()
......
...@@ -36,6 +36,8 @@ from ..utils import (
    cast_if_needed,
    cast_if_needed_inplace,
    divide,
+    save_for_backward_allow_none,
+    saved_tensor_allow_none,
)

__all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"]
...@@ -193,7 +195,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
        )

        if is_grad_enabled:
-            ctx.save_for_backward(
+            save_for_backward_allow_none(
+                ctx,
                inputmat,
                ln_weight,
                mu,
...@@ -217,8 +220,10 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.requires_dgrad = not inp.stop_gradient
+            ctx.requires_wgrad = not weight.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient
            ctx.requires_ln_bgrad = not ln_bias.stop_gradient
+            ctx.requires_ln_wgrad = not ln_weight.stop_gradient

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
...@@ -235,7 +240,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
                ctx.tp_group,
                ctx.tp_size,
                name="_LayerNormLinear"):
-            (
+            (    # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                ln_weight,
                mu,
...@@ -244,7 +249,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
                weight_t_fp8,
                ln_out,
                fwd_scale_inverses,
-            ) = ctx.saved_tensor()
+            ) = saved_tensor_allow_none(ctx)

            (
                grad_output,
...@@ -258,7 +263,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
            if ctx.fp8_enabled:
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                fp8_wgrad = not ctx.fp8_meta["recipe"].override_linear_precision.wgrad
-                if not weight.stop_gradient:
+                if ctx.requires_wgrad:
                    if fp8_wgrad:
                        ln_out_t = transpose(ln_out, fp8_dtype_forward)
                    else:
...@@ -287,6 +292,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
                ctx.fp8_enabled,
                ctx.fp8_meta,
                True,    # Always compute dgrad to feed into LayerNorm bwd
+                ctx.requires_wgrad,
                ctx.activation_dtype,
                ctx.parallel_mode,
                ctx.tensor_parallel,
...@@ -315,9 +321,9 @@ class _LayerNormLinear(paddle.autograd.PyLayer):

        return (
            dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
-            dgamma if not ln_weight.stop_gradient else None,
+            dgamma if ctx.requires_ln_wgrad else None,
            dbeta if ctx.requires_ln_bgrad else None,
-            wgrad if not weight.stop_gradient else None,
+            wgrad if ctx.requires_wgrad else None,
            *bgrad_out,
        )
......
...@@ -35,6 +35,8 @@ from ..utils import (
    cast_if_needed_inplace,
    divide,
    get_paddle_act_func,
+    save_for_backward_allow_none,
+    saved_tensor_allow_none,
)

__all__ = ["LayerNormMLP"]
...@@ -147,6 +149,7 @@ def _mlp_backward(
    fc1_weight_t_fp8: paddle.Tensor,
    fc1_weight_fp8_index: FP8FwdTensors,
    fc1_grad_output_fp8_index: FP8BwdTensors,    # FP8BwdTensors.GRAD_OUTPUT2
+    requires_fc1_wgrad: bool,
    requires_fc1_bgrad: bool,
    fc1_out: paddle.Tensor,
    fc2_input: paddle.Tensor,    # gelu_out
...@@ -154,6 +157,7 @@ def _mlp_backward(
    fc2_weight: paddle.Tensor,
    fc2_weight_t_fp8: paddle.Tensor,
    fc2_weight_fp8_index: FP8FwdTensors,
+    requires_fc2_wgrad: bool,
    requires_fc2_bgrad: bool,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
...@@ -183,7 +187,6 @@ def _mlp_backward(
        # FC2 Bwd
        fc2_input_no_fp8, fc2_input_t = None, None
        fp8_wgrad = not fp8_meta["recipe"].override_linear_precision.wgrad
-        requires_fc2_wgrad = not fc2_weight.stop_gradient
        if requires_fc2_wgrad:
            if fp8_wgrad:
                fc2_input_t = transpose(fc2_input, fp8_dtype_forward)
...@@ -229,7 +232,6 @@ def _mlp_backward(
            fc1_bgrad = fc1_bgrad_

        # FC1 Bwd
-        requires_fc1_wgrad = not fc1_weight.stop_gradient
        dgelu_no_fp8, fc1_input_no_fp8, fc1_input_t = None, None, None
        if requires_fc1_wgrad:
            if fp8_wgrad:
...@@ -277,6 +279,7 @@ def _mlp_backward(
            grad_output,
            requires_fc2_bgrad,
            True,
+            requires_fc2_wgrad,
            activation_dtype,
            'row' if set_parallel_mode else None,
            tensor_parallel,
...@@ -290,6 +293,7 @@ def _mlp_backward(
            dgelu,
            requires_fc1_bgrad,
            requires_dgrad,
+            requires_fc1_wgrad,
            activation_dtype,
            'column' if set_parallel_mode else None,
            tensor_parallel,
...@@ -397,7 +401,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
        )

        if is_grad_enabled:
-            ctx.save_for_backward(
+            save_for_backward_allow_none(
+                ctx,
                inputmat,
                ln_weight,
                mu,
...@@ -426,9 +431,12 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.requires_dgrad = not inp.stop_gradient
+            ctx.requires_fc1_wgrad = not fc1_weight.stop_gradient
+            ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient
            ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient
            ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient
            ctx.requires_ln_bgrad = not ln_bias.stop_gradient
+            ctx.requires_ln_wgrad = not ln_weight.stop_gradient

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1]))
...@@ -446,7 +454,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
                ctx.tp_group,
                ctx.tp_size,
                name="_LayerNormMLP"):
-            (
+            (    # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                ln_weight,
                mu,
...@@ -459,7 +467,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
                fc2_weight,
                fc2_weight_t_fp8,
                fwd_scale_inverses,
-            ) = ctx.saved_tensor()
+            ) = saved_tensor_allow_none(ctx)

            ctx.use_bias = ctx.use_fc2_bias    # For grad_output_preprocess
            (
...@@ -482,6 +490,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
                fc1_weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                FP8BwdTensors.GRAD_OUTPUT2,
+                ctx.requires_fc1_wgrad,
                ctx.requires_fc1_bgrad,
                fc1_out,
                gelu_out,
...@@ -489,6 +498,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
                fc2_weight,
                fc2_weight_t_fp8,
                FP8FwdTensors.GEMM2_WEIGHT,
+                ctx.requires_fc2_wgrad,
                ctx.requires_fc2_bgrad,
                grad_output,
                grad_output_c,
...@@ -528,11 +538,11 @@ class _LayerNormMLP(paddle.autograd.PyLayer):

        return (
            dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
-            dgamma if not ln_weight.stop_gradient else None,
+            dgamma if ctx.requires_ln_wgrad else None,
            dbeta if ctx.requires_ln_bgrad else None,
-            fc1_wgrad if not fc1_weight.stop_gradient else None,
+            fc1_wgrad if ctx.requires_fc1_wgrad else None,
            *fc1_bgrad_out,
-            fc2_wgrad if not fc2_weight.stop_gradient else None,
+            fc2_wgrad if ctx.requires_fc2_wgrad else None,
            *fc2_bgrad_out,
        )
......
...@@ -34,6 +34,8 @@ from ..utils import (
    cast_if_needed_inplace,
    divide,
    get_bias_dtype,
+    save_for_backward_allow_none,
+    saved_tensor_allow_none,
)

__all__ = ["Linear", "_linear_fwd", "_linear_fwd_fp8", "_linear_bwd", "_linear_fwd_non_fp8"]
...@@ -272,6 +274,7 @@ def _linear_bwd_non_fp8(
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
+    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
...@@ -283,7 +286,6 @@ def _linear_bwd_non_fp8(
    Performs Linear Backward. Optionally, fuses GELU backward and dbias.
    """
    dgrad, wgrad, bgrad = None, None, None
-    requires_wgrad = not weight.stop_gradient
    if requires_dgrad:
        dgrad, _, _ = gemm(
            weight,
...@@ -330,13 +332,13 @@ def _linear_bwd(
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
+    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    tp_group: Union[dist_group_type, None],
):
    dgrad, wgrad, bgrad = None, None, None
-    requires_wgrad = not weight.stop_gradient
    if fp8_enabled:
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
...@@ -364,6 +366,7 @@ def _linear_bwd(
            grad_output,
            requires_bgrad,
            requires_dgrad,
+            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
...@@ -449,7 +452,8 @@ class _Linear(paddle.autograd.PyLayer):
        if is_grad_enabled:
            fp8_wgrad = fp8_enabled and not fp8_meta["recipe"].override_linear_precision.wgrad
-            ctx.save_for_backward(
+            save_for_backward_allow_none(
+                ctx,
                inputmat_no_fp8 if not weight.stop_gradient and not fp8_wgrad else None,
                inputmat_t if not weight.stop_gradient and fp8_wgrad else None,
                weight,
...@@ -466,6 +470,7 @@ class _Linear(paddle.autograd.PyLayer):
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.requires_dgrad = not inp.stop_gradient
+            ctx.requires_wgrad = not weight.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient

        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
...@@ -477,13 +482,14 @@ class _Linear(paddle.autograd.PyLayer):
                ctx.tp_group,
                ctx.tp_size,
                name="_Linear"):
-            (
+            (    # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
-            ) = ctx.saved_tensor()
+            ) = saved_tensor_allow_none(ctx)

            (
                grad_output,
...@@ -508,6 +514,7 @@ class _Linear(paddle.autograd.PyLayer):
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
+                ctx.requires_wgrad,
                ctx.activation_dtype,
                ctx.parallel_mode,
                ctx.tensor_parallel,
...@@ -520,12 +527,12 @@ class _Linear(paddle.autograd.PyLayer):
        if not ctx.use_bias:
            return (
-                wgrad if not weight.stop_gradient else None,
+                wgrad if ctx.requires_wgrad else None,
                dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
            )
        return (
-            wgrad if not weight.stop_gradient else None,
+            wgrad if ctx.requires_wgrad else None,
            dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
            bgrad if ctx.requires_bgrad else None,
        )
......
...@@ -70,7 +70,17 @@ class TransformerLayer(paddle.nn.Layer):
        `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    tp_group : ProcessGroup, default = `None`
        tensor parallel process group.
+    attention_dropout_rng_state_name : str, default = `local_seed`
+        Controls the rng state used for dropout on attention probs. The
+        specified rng should be given different seeds on different TP ranks.
+        It will be ignored if `set_parallel_mode` is False.
+    hidden_dropout_rng_state_name : str, default = `global_seed`
+        Controls the rng state used for dropout on hidden states. The
+        specified rng should be given the same seed on all TP ranks. It
+        will be ignored if `set_parallel_mode` is False. The specified
+        name should be registered through
+        `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
+        .add(rng_state_name, seed)`.
    """

    def __init__(self,
...@@ -181,6 +191,7 @@ class TransformerLayer(paddle.nn.Layer):
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
+        recompute_core_attention: bool = False,
    ) -> paddle.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)
...@@ -207,6 +218,11 @@ class TransformerLayer(paddle.nn.Layer):
            Bias tensor for Q * K.T
        set_zero: bool, default = `True`
            Whether to set output tensors to 0 or not before use.
+        recompute_core_attention: bool, default = `False`
+            If true, forward activations for core attention are recomputed
+            during the backward pass in order to save the memory that would
+            otherwise be needed to store the forward activations until
+            backprop.
        """

        if self.self_attn_mask_type != "causal" and attention_mask is not None:
...@@ -222,6 +238,7 @@ class TransformerLayer(paddle.nn.Layer):
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            set_zero=set_zero,
+            recompute_core_attention=recompute_core_attention,
        )

        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
...@@ -248,6 +265,7 @@ class TransformerLayer(paddle.nn.Layer):
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            set_zero=set_zero,
+            recompute_core_attention=recompute_core_attention,
        )

        if self.apply_residual_connection_post_layernorm:
            attention_output, residual = inter_attention_outputs
......
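
A hedged end-to-end sketch of the new flag (constructor arguments are illustrative; the mask is omitted on the assumption that the default self-attention mask type is causal):

# Sketch: recompute only the core attention block of a TransformerLayer.
import paddle
import transformer_engine.paddle as te

layer = te.TransformerLayer(1024, 4096, 16, layer_type='encoder')
x = paddle.uniform([2, 64, 1024])
x.stop_gradient = False
out = layer(x, None, recompute_core_attention=True)
out.mean().backward()    # core attention is re-run here instead of being stored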
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Methods needed for recompute."""

import os
import inspect

from paddle.distributed import fleet

from .constants import RecomputeFunctionNames
from .fp8 import get_global_fp8_state

__all__ = ['recompute', 'is_in_recompute_phase']

_DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0"))


def is_in_recompute_phase():
    """Inspect the call stack to determine whether this is called from
    the backward phase. Paddle has two recompute methods:
    (1) Use RecomputeFunction. The recomputed function is called from `RecomputeFunction.backward`;
    (2) Use paddle.autograd.saved_tensors_hooks. The recomputed function is called from `unpack`."""
    if _DISABLE_RECOMPUTE:
        return False
    frame = inspect.currentframe().f_back
    while frame:
        if frame.f_code.co_name in RecomputeFunctionNames:
            return True
        frame = frame.f_back
    return False


def recompute(function, *args, **kwargs):
    """
    This is a wrapper of paddle.distributed.fleet.utils.recompute. It provides the necessary
    state information for fp8 layers.
    """
    assert not _DISABLE_RECOMPUTE, "Recompute is disabled. " \
        f"Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}."

    global_fp8_state = get_global_fp8_state()
    try:
        global_fp8_state._fp8_recompute_enabled = True
        outputs = fleet.utils.recompute(function, *args, **kwargs)
    finally:
        global_fp8_state._fp8_recompute_enabled = False

    return outputs
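
Because `is_in_recompute_phase` keys off frame names alone, any user function named `unpack` or `backward` on the call stack triggers a match, which is exactly what the assert message in `restore_fp8_meta_tensors` and the NVTE_DISABLE_RECOMPUTE escape hatch guard against. A small illustration (assuming NVTE_DISABLE_RECOMPUTE is unset):

def backward():    # any frame named 'backward' or 'unpack' matches
    return is_in_recompute_phase()

assert is_in_recompute_phase() is False    # no matching frame on the stack
assert backward() is True                  # false positive purely from the name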
...@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""

-from typing import Union
+from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F
...@@ -86,3 +86,38 @@ def divide(numerator: int, denominator: int) -> int:
    the division value."""
    assert (numerator % denominator == 0), f"{numerator} is not divisible by {denominator}"
    return numerator // denominator
+
+
+def save_for_backward_allow_none(ctx, *args) -> None:
+    """Save tensors for backward. Args may include None."""
+    indices_mapping = []
+    tensors_to_save = []
+    for x in args:
+        if isinstance(x, paddle.Tensor):
+            indices_mapping.append(len(tensors_to_save))
+            tensors_to_save.append(x)
+        elif x is None:
+            indices_mapping.append(-1)
+        else:
+            raise ValueError(f"Type {type(x)} is not allowed.")
+
+    ctx._indices_mapping = indices_mapping
+    ctx.save_for_backward(*tensors_to_save)
+
+
+def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]:
+    """Used in pair with `save_for_backward_allow_none`. Get saved tensors from ctx."""
+    assert hasattr(ctx, '_indices_mapping'), "`saved_tensor_allow_none` must be used " \
+        "in pair with `save_for_backward_allow_none`."
+    indices_mapping = ctx._indices_mapping
+    outputs = []
+
+    saved_tensors = ctx.saved_tensor()
+    for index in indices_mapping:
+        if index < 0:
+            outputs.append(None)
+        else:
+            outputs.append(saved_tensors[index])
+    return tuple(outputs)
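
A minimal PyLayer sketch pairing the two helpers (the layer is hypothetical; note that None entries round-trip without ever being passed to `ctx.save_for_backward`):

# Sketch: a PyLayer that saves a tensor alongside a None placeholder.
import paddle

class _Scale(paddle.autograd.PyLayer):

    @staticmethod
    def forward(ctx, x):
        save_for_backward_allow_none(ctx, x, None)
        return x * 2.0

    @staticmethod
    def backward(ctx, dy):
        x, nothing = saved_tensor_allow_none(ctx)    # pylint: disable=unbalanced-tuple-unpacking
        assert nothing is None
        return dy * 2.0

x = paddle.ones([3])
x.stop_gradient = False
_Scale.apply(x).sum().backward()    # x.grad is all 2.0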