Unverified commit e1edaaec authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Reduce CPU overheads (#2377)



Initial changes to reduce PyTorch CPU overheads.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 42d22740
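
The change that recurs across the PyTorch modules below is to pack every non-tensor argument of each torch.autograd.Function into a single tuple, so apply() inspects only the tensors plus one extra object instead of dozens of positional arguments, and backward() returns a single matching None. A minimal sketch of that pattern with a toy function (all names here are illustrative, not Transformer Engine APIs):

    import torch


    class _ToyLinear(torch.autograd.Function):
        """Toy autograd function: non-tensor settings arrive packed in one tuple."""

        @staticmethod
        def forward(ctx, inp: torch.Tensor, weight: torch.Tensor, non_tensor_args: tuple):
            # Unpack once here instead of passing many positional arguments,
            # which trims PyTorch's per-argument processing on every call.
            (scale, is_grad_enabled) = non_tensor_args
            if is_grad_enabled:
                ctx.save_for_backward(inp, weight)
                ctx.scale = scale
            return scale * (inp @ weight.t())

        @staticmethod
        def backward(ctx, grad_out: torch.Tensor):
            inp, weight = ctx.saved_tensors
            # A single None covers the packed tuple instead of a long row of Nones.
            return ctx.scale * (grad_out @ weight), ctx.scale * (grad_out.t() @ inp), None


    def toy_linear(inp: torch.Tensor, weight: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Mirror of the calling convention used in this commit: .apply under grad,
        # a direct .forward call with a placeholder ctx when grad is disabled.
        is_grad_enabled = torch.is_grad_enabled()
        if is_grad_enabled:
            fn, autograd_ctx = _ToyLinear.apply, []
        else:
            fn, autograd_ctx = _ToyLinear.forward, [None]
        return fn(*autograd_ctx, inp, weight, (scale, is_grad_enabled))
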
@@ -2489,7 +2489,6 @@ class _custom_mha_fp8(torch.autograd.Function):
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
-       workspace: torch.Tensor,
        is_training: bool,
        mask_type: str,
        quantizers: list[Quantizer],
@@ -2518,7 +2517,6 @@ class _custom_mha_fp8(torch.autograd.Function):
        qkv, *_ = ext.general_gemm(
            qkv_weight_fp8,
            inp_fp8,
-           workspace,
            bias=qkv_bias,
            out_dtype=qkv_weight_fp8.dtype,
            quantization_params=qkv_quantizer,
@@ -2560,9 +2558,7 @@ class _custom_mha_fp8(torch.autograd.Function):
            s_quantizer=s_quantizer,
        )
-       tensors_to_save, tensor_objects = prepare_for_saving(
-           q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
-       )
+       tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
@@ -2592,7 +2588,7 @@ class _custom_mha_fp8(torch.autograd.Function):
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        with torch.cuda.nvtx.range("_DPA"):
            saved_tensors = ctx.saved_tensors
-           (q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
+           (q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
                ctx.tensor_objects, saved_tensors
            )
@@ -2648,7 +2644,6 @@ class _custom_mha_fp8(torch.autograd.Function):
            qkv_dgrad, *_ = ext.general_gemm(
                qkv_weight_fp8,
                dqkv_c,
-               workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_DGRAD,
                layout="NN",
@@ -2658,7 +2653,6 @@ class _custom_mha_fp8(torch.autograd.Function):
            qkv_wgrad, *_ = ext.general_gemm(
                inp_fp8,
                dqkv,
-               workspace,
                ctx.dtype,
                use_split_accumulator=_2X_ACC_WGRAD,
                layout="NT",
@@ -2709,9 +2703,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
-       self.workspace = torch.empty(
-           _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
-       )

    def forward(
        self,
@@ -2730,7 +2721,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
            max_s,
            self.fast_zero_fill,
            self.fp8_meta,
-           self.workspace,
            self.training,
            self.mask_type,
            self.quantizers,
......
@@ -82,7 +82,6 @@ def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split
    out, *_ = tepytorch.cpp_extensions.general_gemm(
        fp8_tensor1,
        fp8_tensor2,
-       tepytorch.module.base.get_workspace(),
        torch.float32,
        use_split_accumulator=use_split_accumulator,
    )
@@ -199,7 +198,6 @@ def _emulate_linear(
    wgrad, *_ = tepytorch.cpp_extensions.general_gemm(
        wgrad_input,
        wgrad_gradient,
-       tepytorch.module.base.get_workspace(),
        torch.float32,
        layout="NT",
        grad=True,
......
@@ -7,7 +7,7 @@ import sys
import pytest
import torch
import transformer_engine
-from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear
+from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear, GroupedLinear

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
@@ -19,7 +19,9 @@ model_configs = {
@pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
+@pytest.mark.parametrize(
+    "module", ["TransformerLayer", "DotProductAttention", "Linear", "GroupedLinear"]
+)
def test_current_device(model, module):
    """Test cases where current device is different from tensor device"""
@@ -58,7 +60,7 @@ def test_current_device(model, module):
        kwargs["cu_seqlens_kv"] = cu_seqlens_kv
        kwargs["max_seqlen_q"] = config.max_seqlen_q
        kwargs["max_seqlen_kv"] = config.max_seqlen_kv
-   if module == "DotProductAttention":
+   elif module == "DotProductAttention":
        model = DotProductAttention(
            config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
        )
@@ -97,6 +99,24 @@ def test_current_device(model, module):
                requires_grad=True,
            )
        ]
+   elif module == "GroupedLinear":
+       num_gemms = 4
+       model = GroupedLinear(
+           num_gemms,
+           config.hidden_size,
+           4 * config.hidden_size,
+           params_dtype=dtype,
+           device=f"cuda:{tensor_device}",
+       )
+       args = [
+           torch.randn(
+               (config.max_seqlen_q * config.batch_size * (num_gemms - 1), config.hidden_size),
+               dtype=dtype,
+               device=f"cuda:{tensor_device}",
+               requires_grad=True,
+           ),
+           [0] + [config.max_seqlen_q * config.batch_size] * (num_gemms - 1),  # Empty first split.
+       ]

    current_device_before = torch.cuda.current_device()
    out = model(*args, **kwargs)
......
@@ -44,7 +44,6 @@ from transformer_engine.pytorch import (
)
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
-from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states
@@ -2690,7 +2689,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
        general_gemm(
            A[i],
            B[i],
-           get_workspace(),
            dtype,
            grad=grad,
            accumulate=accumulate,
@@ -2705,7 +2703,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
        B,
        out,
        dtype,
-       get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        grad=grad,
        accumulate=accumulate,
@@ -2760,7 +2757,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
    quantized_out, *_ = general_gemm(
        weight_fp8,
        inp_fp8,
-       get_workspace(),
        outp_type,
        quantization_params=out_quantizer,
        bias=None,
@@ -2770,7 +2766,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
    out, *_ = general_gemm(
        weight_fp8,
        inp_fp8,
-       get_workspace(),
        outp_type,
        quantization_params=None,
        bias=None,
@@ -2846,7 +2841,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
        general_gemm(
            A_fp8[i],
            B_fp8[i],
-           get_workspace(),
            dtype,
            out=out_ref[i],
            accumulate=accumulate,
@@ -2856,7 +2850,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
        B_fp8,
        out,
        dtype,
-       get_multi_stream_cublas_workspace(),
        m_splits=m_splits,
        accumulate=accumulate,
    )
......
@@ -36,7 +36,6 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
-from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from utils import ModelConfig
@@ -912,7 +911,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
    inp = torch.reshape(scratchpad[offset:-offset], (N, N))
    weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
-   _ = general_gemm(A=weight, B=inp, workspace=get_workspace())
+   _ = general_gemm(A=weight, B=inp)
    torch.cuda.synchronize()
@@ -936,7 +935,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    general_gemm(
        weight_fp8,
        inp_fp8,
-       get_workspace(),
        outp_type,
        bias=None,
        use_split_accumulator=False,
......
@@ -19,7 +19,12 @@ from transformer_engine.pytorch.utils import (
    get_device_compute_capability,
    split_tensor_along_dim,
)
-from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop
+from transformer_engine.pytorch.utils import (
+    attention_mask_func,
+    nvtx_range_push,
+    nvtx_range_pop,
+    get_nvtx_range_context,
+)
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
@@ -1445,7 +1450,7 @@ class FusedAttnFunc(torch.autograd.Function):
                dk = dk[..., : d_out.shape[-1]]
                dv = dv[..., : d_out.shape[-1]]
            else:
-               with torch.cuda.nvtx.range("FusedAttnFunc.backward"):
+               with get_nvtx_range_context("FusedAttnFunc.backward"):
                    # get nominal data type of dq, dk, dv
                    # FP16/BF16 attention: torch.float16 or torch.bfloat16
                    # FP8 attention: torch.float16 or torch.bfloat16
......
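
The attention backend also swaps torch.cuda.nvtx.range for a get_nvtx_range_context helper imported from transformer_engine.pytorch.utils. The helper's actual behavior lives in that module; the sketch below is only a guess at the general shape, assuming the intent is to make the NVTX range a no-op unless profiling was explicitly requested, so the hot path does not pay for range bookkeeping on every call:

    import contextlib
    import os

    import torch

    # Assumed switch for the sketch; the real helper may key off something else entirely.
    _NVTX_ENABLED = bool(int(os.getenv("EXAMPLE_ENABLE_NVTX", "0")))


    def example_nvtx_range_context(msg: str):
        """Return an NVTX range when profiling is on, else a no-op context."""
        if _NVTX_ENABLED:
            return torch.cuda.nvtx.range(msg)
        return contextlib.nullcontext()
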
@@ -975,7 +975,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            Whether to enforce output to be in FP8 or not.
        """
-       with torch.cuda.device(query_layer.device), self.prepare_forward(
+       with self.prepare_forward(
            query_layer,
            num_gemms=3,
            allow_non_contiguous=True,
......
@@ -6,23 +6,59 @@
from typing import Iterable, Optional, Tuple, Union, List
import os
+import functools

import torch
import transformer_engine_torch as tex

from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor
-from ..quantized_tensor import Quantizer
+from ..quantized_tensor import Quantizer, QuantizedTensor, QuantizedTensorStorage
+from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
+from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
+from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_custom
from ..custom_recipes.gemm import custom_gemm
from ...debug.pytorch.debug_quantization import DebugQuantizer

__all__ = [
    "general_gemm",
    "general_grouped_gemm",
]

+_NUM_MAX_UB_STREAMS = 3
+
+
+def get_cublas_workspace_size_bytes() -> None:
+    """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
+    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
+        # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
+        return 32 * 1024 * 1024 + 1024
+    return 4_194_304
+
+
+@functools.lru_cache(maxsize=None)
+def get_cublas_workspace(device: int, ub: bool, grouped_gemm: bool) -> torch.Tensor:
+    """Returns workspace for cublas GEMM."""
+    assert not (ub and grouped_gemm), "UB is unsupported for grouped GEMM."
+    if ub:
+        return torch.empty(
+            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device
+        ).repeat(_NUM_MAX_UB_STREAMS)
+    if grouped_gemm:
+        _multi_stream_cublas_workspace = []
+        for _ in range(tex.get_num_cublas_streams()):
+            _multi_stream_cublas_workspace.append(
+                torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
+            )
+        return _multi_stream_cublas_workspace
+    return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
+
+
def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
    """Validate whether a GEMM scaling factor is consistent with its usage"""
    if required:
@@ -32,10 +68,35 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
    return 0.0

+def get_tensor_device(tensor: torch.Tensor) -> int:
+    """Returns tensor device as an integer"""
+    if not isinstance(tensor, QuantizedTensorStorage):
+        return tensor.device.index
+    if isinstance(tensor, QuantizedTensor):
+        return tensor.device.index
+    if isinstance(tensor, (Float8BlockwiseQTensorStorage, MXFP8TensorStorage, NVFP4TensorStorage)):
+        return (
+            tensor._rowwise_data.device.index
+            if tensor._rowwise_data is not None
+            else tensor._columnwise_data.device.index
+        )
+    if isinstance(tensor, Float8TensorStorage):
+        return (
+            tensor._data.device.index
+            if tensor._data is not None
+            else tensor._transpose.device.index
+        )
+    try:
+        return (
+            tensor._data.device.index if tensor._data is not None else tensor._data_t.device.index
+        )
+    except AttributeError:
+        return torch.cuda.current_device()
+
+
def general_gemm(
    A: torch.Tensor,
    B: torch.Tensor,
-   workspace: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
    quantization_params: Optional[Quantizer] = None,
    gelu: bool = False,
@@ -62,6 +123,7 @@ def general_gemm(
    alpha = validate_gemm_scale(alpha, True)
    beta = validate_gemm_scale(beta, accumulate)
+   workspace = get_cublas_workspace(get_tensor_device(A), ub is not None, False)

    if ub_type is not None:
        assert ub is not None, (
@@ -159,7 +221,6 @@ def general_grouped_gemm(
    B: List[torch.Tensor],
    out: List[torch.Tensor],
    out_dtype: torch.dtype,
-   workspaces: List[torch.Tensor],
    layout: str = "TN",
    m_splits: Optional[List[int]] = None,
    gelu: bool = False,
@@ -187,6 +248,8 @@ def general_grouped_gemm(
    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
    sm_count = get_sm_count()
+   workspaces = get_cublas_workspace(get_tensor_device(A[0]), False, True)

    if grad and use_bias:
        grad_bias = [
            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
......
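
With the move into cpp_extensions/gemm.py, the cuBLAS workspace is resolved inside general_gemm and general_grouped_gemm through a functools.lru_cache keyed on (device, ub, grouped_gemm), so callers no longer thread a workspace tensor through every call and each device allocates its scratch buffer exactly once. The same memoization idea in isolation (toy size and names, not the library's functions):

    import functools

    import torch

    _TOY_WORKSPACE_BYTES = 4 * 1024 * 1024  # illustrative; the real size depends on the GPU architecture


    @functools.lru_cache(maxsize=None)
    def _cached_workspace(device_index: int) -> torch.Tensor:
        """Allocate one scratch buffer per CUDA device, lazily on first use."""
        return torch.empty(_TOY_WORKSPACE_BYTES, dtype=torch.uint8, device=f"cuda:{device_index}")


    def run_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # After the first call on a device this is a dict lookup, so there is no
        # per-call allocation and no extra argument at the call site.
        _ = _cached_workspace(a.device.index)
        return a @ b  # stand-in for the extension call that would consume the workspace

One trade-off to note: lru_cache keeps the cached workspaces alive for the life of the process, which is what makes repeated GEMM calls cheap.
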
@@ -108,6 +108,11 @@ std::vector<py::object> fused_attn_fwd(
    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
    const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
    size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) {
+ // Ensure that cuDNN handle is created on the correct device,
+ // overriding torch.cuda.set_device calls from user side.
+ // Assumes all tensors passed are on the same device.
+ at::cuda::CUDAGuard device_guard(cu_seqlens_q.device());
+
  auto none = py::none();

  // create QKV tensor wrappers
......
@@ -95,6 +95,11 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                             bool bulk_overlap, float alpha, std::optional<float> beta) {
  using namespace transformer_engine::pytorch::detail;

+ // Ensure that cublasLt handle is created on the correct device,
+ // overriding torch.cuda.set_device calls from user side.
+ // Assumes all tensors passed are on the same device.
+ at::cuda::CUDAGuard device_guard(workspace.device());
+
  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
@@ -351,6 +356,11 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                    at::Tensor workspace, size_t workspaceSize, bool accumulate,
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, at::Tensor counter) {
+ // Ensure that cublasLt handle is created on the correct device,
+ // overriding torch.cuda.set_device calls from user side.
+ // Assumes all tensors passed are on the same device.
+ at::cuda::CUDAGuard device_guard(workspace.device());
+
  // TODO: Handle scaling modes
  NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
  NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
@@ -400,6 +410,11 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    NVTE_ERROR("not implemented, D should be allocated for single output case.");
  }

+ // Ensure that cublasLt handle is created on the correct device,
+ // overriding torch.cuda.set_device calls from user side.
+ // Assumes all tensors passed are on the same device.
+ at::cuda::CUDAGuard device_guard(workspace[0].device());
+
  void* output_data_ptr = nullptr;
  if (single_output) {
    output_data_ptr = (*D)[0].data_ptr();
......
@@ -64,6 +64,11 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
                                      const bool zero_centered_gamma) {
  using namespace transformer_engine::pytorch::detail;

+ // Ensure that cuDNN handle is created on the correct device,
+ // overriding torch.cuda.set_device calls from user side.
+ // Assumes all tensors passed are on the same device.
+ at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
+
  // Input and param tensors
  auto none = py::none();
  const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
@@ -294,6 +299,11 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
                                    const int sm_margin, const bool zero_centered_gamma) {
  using namespace transformer_engine::pytorch::detail;

+ // Ensure that cuDNN handle is created on the correct device,
+ // overriding torch.cuda.set_device calls from user side.
+ // Assumes all tensors passed are on the same device.
+ at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
+
  // Input and param tensors
  auto none = py::none();
  const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
......
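
The C++ entry points above now construct an at::cuda::CUDAGuard from one of their tensor arguments, so handles and kernels end up on the tensors' device even when the caller has changed the current device. This is roughly the behavior the Python layer used to obtain by wrapping extension calls in torch.cuda.device, sketched here with a stand-in callable (illustrative only, not the library's API):

    import torch


    def call_on_tensor_device(extension_call, x: torch.Tensor):
        """Run extension_call(x) with the current device pinned to x's device.

        The commit removes wrappers like this from the Python modules because the
        equivalent guard now lives inside the C++ bindings themselves.
        """
        with torch.cuda.device(x.device):
            return extension_call(x)
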
@@ -39,13 +39,18 @@ from ..distributed import (
    _fsdp_gather_tensors,
)
from ..constants import dist_group_type
+from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
-from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
+from ..utils import (
+    is_non_tn_fp8_gemm_supported,
+    torch_get_autocast_gpu_dtype,
+    get_nvtx_range_context,
+)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
@@ -57,11 +62,8 @@ __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
-_multi_stream_cublas_workspace = []
_dummy_wgrads = {}
-_cublas_workspace = None
_ub_communicators = None
-_NUM_MAX_UB_STREAMS = 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []
@@ -75,35 +77,6 @@ class UserBufferQuantizationMode(Enum):
    FP8 = "fp8"

-def get_cublas_workspace_size_bytes() -> None:
-    """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
-    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
-        # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
-        return 32 * 1024 * 1024 + 1024
-    return 4_194_304
-
-
-def get_workspace() -> torch.Tensor:
-    """Returns workspace for cublas."""
-    global _cublas_workspace
-    if _cublas_workspace is None:
-        _cublas_workspace = torch.empty(
-            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
-        )
-    return _cublas_workspace
-
-
-def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
-    """Returns workspace for multi-stream cublas."""
-    global _multi_stream_cublas_workspace
-    if not _multi_stream_cublas_workspace:
-        for _ in range(tex.get_num_cublas_streams()):
-            _multi_stream_cublas_workspace.append(
-                torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
-            )
-    return _multi_stream_cublas_workspace
-
-
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
    """Returns a dummy tensor of given shape."""
    assert len(shape) == 2
@@ -276,16 +249,6 @@ def initialize_ub(
            flush=True,
        )

-   # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
-   global _cublas_workspace
-   if _cublas_workspace is None:
-       _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
-   elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
-       # This ensures we don't do `.repeat()` on an already expanded workspace
-       _cublas_workspace = torch.empty(
-           get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
-       ).repeat(_NUM_MAX_UB_STREAMS)
-
    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop",
@@ -1078,8 +1041,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        """
        self.allow_different_data_and_param_types = allow_different_data_and_param_types
        self.forwarded_at_least_once = True

        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
+           delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."
@@ -1091,25 +1056,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            self.init_fp8_metadata(num_gemms=num_gemms)
            self._check_weight_tensor_recipe_correspondence()

-           if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
+           delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
+           if delayed_scaling_recipe:
+               if self.sequence_parallel:
                    assert self.fp8_meta["recipe"].reduce_amax, (
                        "Amax reduction across tensor parallel group is "
                        "necessary when using sequence parallelism with FP8."
                    )

-           if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
+               if not FP8GlobalStateManager.fp8_graph_capturing():
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

                # Activation recomputation is used and this is the first forward phase.
-           if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
+               if self.training and is_fp8_activation_recompute_enabled():
                    FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

-       with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
+       with get_nvtx_range_context(self.__class__.__name__ + " forward"):
            if not allow_non_contiguous and not inp.is_contiguous():
                inp = inp.contiguous()
            yield inp

-       if self.fp8 and in_fp8_activation_recompute_phase():
+       if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)

    def set_nccl_overlap_warning_if_tp(self) -> None:
@@ -1531,7 +1498,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        """
        if not self.need_backward_dw():
            return
-       with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
+       with get_nvtx_range_context(f"_{self.__class__.__name__}_wgrad"):
            (wgrad, bgrad), _ = self.wgrad_store.pop()
            if not self.fuse_wgrad_accumulation:
                weight_tensor = noop_cat(self._get_weight_tensors())
@@ -1628,6 +1595,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        """
        if not self.fp8 and not self.fp8_calibration:
            return
+       if not self.primary_weights_in_fp8:
+           return
        if not hasattr(self, "weight_names") or not self.weight_names:
            return
......
@@ -24,11 +24,14 @@ class _Fp8Padding(torch.autograd.Function):
    def forward(
        ctx,
        inp: torch.Tensor,
-       m_splits: List[int],
-       padded_m_splits: List[int],
-       is_grad_enabled: bool,
+       non_tensor_args: Tuple,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
+       # Reduce number of arguments to autograd function in order
+       # to reduce CPU overhead due to pytorch arg checking.
+       (m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args

        # Make sure input dimensions are compatible
        in_features = inp.shape[-1]
@@ -65,7 +68,7 @@ class _Fp8Padding(torch.autograd.Function):
                grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits
            )
-       return (grad_input, None, None, None)
+       return grad_input, None

class Fp8Padding(torch.nn.Module):
@@ -128,19 +131,20 @@ class Fp8Padding(torch.nn.Module):
        if m_splits == padded_m_splits:
            return inp, m_splits

-       if torch.is_grad_enabled():
+       is_grad_enabled = torch.is_grad_enabled()
+       if is_grad_enabled:
            fn = _Fp8Padding.apply
-           args = []
+           autograd_ctx = []
        else:
            fn = _Fp8Padding.forward
-           args = [None]
-       args += (
-           inp,
+           autograd_ctx = [None]
+       non_tensor_args = (
            m_splits,
            padded_m_splits,
-           torch.is_grad_enabled(),
+           is_grad_enabled,
        )
-       out = fn(*args)
+       out = fn(*autograd_ctx, inp, non_tensor_args)

        return out, padded_m_splits

@@ -4,7 +4,7 @@
"""FP8 Padding API"""

-from typing import List, Optional
+from typing import List, Optional, Tuple

import torch
@@ -24,11 +24,14 @@ class _Fp8Unpadding(torch.autograd.Function):
    def forward(
        ctx,
        inp: torch.Tensor,
-       m_splits: List[int],
-       padded_m_splits: List[int],
-       is_grad_enabled: bool,
+       non_tensor_args: Tuple,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
+       # Reduce number of arguments to autograd function in order
+       # to reduce CPU overhead due to pytorch arg checking.
+       (m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args

        in_features = inp.shape[-1]

        # Allocate cast and transpose output tensor
@@ -63,7 +66,7 @@ class _Fp8Unpadding(torch.autograd.Function):
                grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits
            )
-       return (grad_input, None, None, None)
+       return grad_input, None

class Fp8Unpadding(torch.nn.Module):
@@ -126,19 +129,20 @@ class Fp8Unpadding(torch.nn.Module):
        if m_splits == padded_m_splits:
            return inp

-       if torch.is_grad_enabled():
+       is_grad_enabled = torch.is_grad_enabled()
+       if is_grad_enabled:
            fn = _Fp8Unpadding.apply
-           args = []
+           autograd_ctx = []
        else:
            fn = _Fp8Unpadding.forward
-           args = [None]
-       args += (
-           inp,
+           autograd_ctx = [None]
+       non_tensor_args = (
            m_splits,
            padded_m_splits,
-           torch.is_grad_enabled(),
+           is_grad_enabled,
        )
-       out = fn(*args)
+       out = fn(*autograd_ctx, inp, non_tensor_args)

        return out
@@ -14,7 +14,6 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from .base import (
    get_dummy_wgrad,
-   get_multi_stream_cublas_workspace,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
@@ -28,6 +27,7 @@ from ..utils import (
    clear_tensor_data,
    init_method_constant,
    requires_grad,
+   get_nvtx_range_context,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
@@ -40,7 +40,6 @@ from ..cpp_extensions import (
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
-from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
@@ -63,28 +62,34 @@ class _GroupedLinear(torch.autograd.Function):
    def forward(
        ctx,
        inp: torch.Tensor,
-       m_splits: List[int],
-       use_bias: bool,
-       is_first_microbatch: Union[bool, None],
-       fp8: bool,
-       fp8_calibration: bool,
-       wgrad_store: WeightGradStore,
-       input_quantizers: List[Quantizer],
-       weight_quantizers: List[Quantizer],
-       output_quantizers: List[Quantizer],
-       grad_output_quantizers: List[Quantizer],
-       fuse_wgrad_accumulation: bool,
-       cpu_offloading: bool,
-       sequence_parallel: bool,
-       activation_dtype: torch.dtype,
-       is_grad_enabled: bool,
-       module,
-       skip_fp8_weight_update,
-       save_original_input,
+       non_tensor_args: Tuple,
        *weights_and_biases,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
+       # Reduce number of arguments to autograd function in order
+       # to reduce CPU overhead due to pytorch arg checking.
+       (
+           m_splits,
+           use_bias,
+           is_first_microbatch,
+           fp8,
+           fp8_calibration,
+           wgrad_store,
+           input_quantizers,
+           weight_quantizers,
+           output_quantizers,
+           grad_output_quantizers,
+           fuse_wgrad_accumulation,
+           cpu_offloading,
+           sequence_parallel,
+           activation_dtype,
+           is_grad_enabled,
+           module,
+           skip_fp8_weight_update,
+           save_original_input,
+       ) = non_tensor_args

        num_gemms = len(m_splits)
        weights = weights_and_biases[:num_gemms]
        biases = weights_and_biases[num_gemms:]
@@ -183,7 +188,6 @@ class _GroupedLinear(torch.autograd.Function):
                inputmats,
                [out],
                activation_dtype,
-               get_multi_stream_cublas_workspace(),
                single_output=True,
                m_splits=m_splits,
                bias=biases,
@@ -284,7 +288,7 @@ class _GroupedLinear(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
-       with torch.cuda.nvtx.range("_GroupedLinear_backward"):
+       with get_nvtx_range_context("_GroupedLinear_backward"):
            saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
            N = ctx.num_gemms
            inputmats = saved_tensors[:N]
@@ -372,7 +376,6 @@ class _GroupedLinear(torch.autograd.Function):
                grad_output,
                [dgrad],
                ctx.activation_dtype,
-               get_multi_stream_cublas_workspace(),
                single_output=True,
                layout="NN",
                m_splits=ctx.m_splits,
@@ -419,7 +422,6 @@ class _GroupedLinear(torch.autograd.Function):
            grouped_gemm_wgrad = functools.partial(
                general_grouped_gemm,
                out_dtype=ctx.activation_dtype,
-               workspaces=get_multi_stream_cublas_workspace(),
                layout="NT",
                grad=True,
                m_splits=ctx.m_splits,
@@ -484,28 +486,11 @@ class _GroupedLinear(torch.autograd.Function):
            ):
                grad_biases = [None] * ctx.num_gemms

-       if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
+       if ctx.reduce_and_update_bwd_fp8_tensors:
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
-           None,
            *wgrad_list,
            *grad_biases,
        )
@@ -765,16 +750,9 @@ class GroupedLinear(TransformerEngineBaseModule):
        ), "GroupedLinear doesn't support input tensor in FP8."
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

-       if FP8GlobalStateManager.fp8_graph_capturing():
-           skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
-       else:
-           skip_fp8_weight_update = None
-       if skip_fp8_weight_update is not None:
-           is_first_microbatch = False
+       is_grad_enabled = torch.is_grad_enabled()

-       with torch.cuda.device(
-           getattr(self, list(self.named_parameters())[0][0]).device
-       ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+       with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
            weight_tensors = self._get_weight_tensors()
            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
@@ -794,7 +772,7 @@ class GroupedLinear(TransformerEngineBaseModule):
            # TODO: use internal after #1638 is merged.  # pylint: disable=fixme
            for i in range(self.num_gemms):
                input_quantizers[i].internal = False
-           if torch.is_grad_enabled():
+           if is_grad_enabled:
                grad_output_quantizers = [
                    self.quantizers["scaling_bwd"][
                        self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
@@ -804,14 +782,14 @@ class GroupedLinear(TransformerEngineBaseModule):
                for i in range(self.num_gemms):
                    grad_output_quantizers[i].internal = True

-           if torch.is_grad_enabled():
+           if is_grad_enabled:
                linear_fn = _GroupedLinear.apply
-               args = []
+               autograd_ctx = []
            else:
                linear_fn = _GroupedLinear.forward
-               args = [None]
-           args += (
-               inp,
+               autograd_ctx = [None]
+           non_tensor_args = (
                m_splits,
                self.apply_bias,
                is_first_microbatch,
@@ -826,14 +804,12 @@ class GroupedLinear(TransformerEngineBaseModule):
                is_cpu_offload_enabled(),
                self.sequence_parallel,
                self.activation_dtype,
-               torch.is_grad_enabled(),
+               is_grad_enabled,
                self,
-               skip_fp8_weight_update,
+               None,  # skip_fp8_weight_update
                self.save_original_input,
-               *weight_tensors,
-               *bias_tensors,
            )
-           out = linear_fn(*args)
+           out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

        if self.return_bias:
            return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
@@ -846,7 +822,7 @@ class GroupedLinear(TransformerEngineBaseModule):
        """
        if not self.need_backward_dw():
            return
-       with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
+       with get_nvtx_range_context("_GroupedLinear_wgrad"):
            (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
            wgrad_list = tensor_list[2]
            weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
......
@@ -19,7 +19,6 @@ from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_custom
from .base import (
    fill_userbuffers_buffer_for_all_gather,
-   get_workspace,
    get_ub,
    TransformerEngineBaseModule,
    get_dummy_wgrad,
@@ -40,6 +39,7 @@ from ..utils import (
    nvtx_range_push,
    requires_grad,
    needs_quantized_gemm,
+   get_nvtx_range_context,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
@@ -96,47 +96,53 @@ class _LayerNormLinear(torch.autograd.Function):
        ln_bias: Union[torch.Tensor, None],
        weight: torch.Tensor,
        bias: torch.Tensor,
-       eps: float,
-       is_first_microbatch: Union[bool, None],
-       fp8: bool,
-       fp8_calibration: bool,
-       wgrad_store: WeightGradStore,
-       fuse_wgrad_accumulation: bool,
-       input_quantizer: Optional[Quantizer],
-       weight_quantizer: Optional[Quantizer],
-       output_quantizer: Optional[Quantizer],
-       grad_input_quantizer: Optional[Quantizer],
-       grad_weight_quantizer: Optional[Quantizer],
-       grad_output_quantizer: Optional[Quantizer],
-       cpu_offloading: bool,
-       tp_group: Union[dist_group_type, None],
-       tp_size: int,
-       sequence_parallel: bool,
-       tensor_parallel: bool,
-       activation_dtype: torch.dtype,
-       parallel_mode: Union[str, None],
-       return_layernorm_output: bool,
-       return_layernorm_output_gathered: bool,
-       is_grad_enabled: bool,
-       fwd_ln_sm_margin: int,
-       bwd_ln_sm_margin: int,
-       zero_centered_gamma: bool,
-       normalization: str,
-       ub_overlap_ag_fprop: bool,
-       ub_overlap_rs_fprop: bool,
-       ub_overlap_ag_dgrad: bool,
-       ub_overlap_rs_dgrad: bool,
-       ub_bulk_wgrad: bool,
-       ub_bulk_dgrad: bool,
-       ub_name: str,
-       fsdp_group: Union[dist_group_type, None],
-       module: torch.nn.Module,
-       skip_fp8_weight_update: bool,
-       symmetric_ar_type: str,
-       debug: Optional[bool] = False,
+       non_tensor_args: Tuple,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # pylint: disable=missing-function-docstring
+       # Reduce number of arguments to autograd function in order
+       # to reduce CPU overhead due to pytorch arg checking.
+       (
+           eps,
+           is_first_microbatch,
+           fp8,
+           fp8_calibration,
+           wgrad_store,
+           fuse_wgrad_accumulation,
+           input_quantizer,
+           weight_quantizer,
+           output_quantizer,
+           grad_input_quantizer,
+           grad_weight_quantizer,
+           grad_output_quantizer,
+           cpu_offloading,
+           tp_group,
+           tp_size,
+           sequence_parallel,
+           tensor_parallel,
+           activation_dtype,
+           parallel_mode,
+           return_layernorm_output,
+           return_layernorm_output_gathered,
+           is_grad_enabled,
+           fwd_ln_sm_margin,
+           bwd_ln_sm_margin,
+           zero_centered_gamma,
+           normalization,
+           ub_overlap_ag_fprop,
+           ub_overlap_rs_fprop,
+           ub_overlap_ag_dgrad,
+           ub_overlap_rs_dgrad,
+           ub_bulk_wgrad,
+           ub_bulk_dgrad,
+           ub_name,
+           fsdp_group,
+           module,
+           skip_fp8_weight_update,
+           symmetric_ar_type,
+           debug,
+       ) = non_tensor_args

        # NVTX label for profiling
        nvtx_label = "transformer_engine._LayerNormLinear.forward"
        if ub_name is not None:
@@ -355,7 +361,6 @@ class _LayerNormLinear(torch.autograd.Function):
            gemm_out, *_, reduce_scatter_out = general_gemm(
                weightmat,
                ln_out_total,
-               get_workspace(),
                quantization_params=output_quantizer,
                out_dtype=activation_dtype,
                bias=bias,
@@ -544,7 +549,7 @@ class _LayerNormLinear(torch.autograd.Function):
        if ctx.ub_name is not None:
            nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

-       with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
+       with get_nvtx_range_context("_LayerNormLinear_backward"):
            saved_tensors = ctx.saved_tensors
            (  # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
@@ -731,7 +736,6 @@ class _LayerNormLinear(torch.autograd.Function):
                gemm_out, *_, reduce_scatter_out = general_gemm(
                    weight,
                    grad_output,
-                   get_workspace(),
                    layout="NN",
                    grad=True,
                    quantization_params=ctx.grad_input_quantizer,
@@ -858,7 +862,6 @@ class _LayerNormLinear(torch.autograd.Function):
            # Arguments to include in wgrad GEMM closure
            wgrad_gemm_kwargs = {
-               "workspace": get_workspace(),
                "out_dtype": (
                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                ),
@@ -1026,44 +1029,7 @@ class _LayerNormLinear(torch.autograd.Function):
            dbeta,
            wgrad,
            grad_bias,
-           None,  # eps
-           None,  # is_first_microbatch
-           None,  # fp8
-           None,  # fp8_calibration
-           None,  # wgrad_store
-           None,  # fuse_wgrad_accumulation
-           None,  # input_quantizer
-           None,  # weight_quantizer
-           None,  # output_quantizer
-           None,  # grad_input_quantizer
-           None,  # grad_weight_quantizer
-           None,  # grad_output_quantizer
-           None,  # cpu_offloading
-           None,  # tp_group
-           None,  # tp_size
-           None,  # sequence_parallel
-           None,  # tensor_parallel
-           None,  # activation_dtype
-           None,  # parallel_mode
-           None,  # return_layernorm_output
-           None,  # return_layernorm_output_gathered
-           None,  # is_grad_enabled
-           None,  # fwd_ln_sm_margin
-           None,  # bwd_ln_sm_margin
-           None,  # zero_centered_gamma
-           None,  # normalization
-           None,  # ub_overlap_ag_fprop
-           None,  # ub_overlap_rs_fprop
-           None,  # ub_overlap_ag_dgrad
-           None,  # ub_overlap_rs_dgrad
-           None,  # ub_bulk_dgrad
-           None,  # ub_bulk_wgrad
-           None,  # ub_name
-           None,  # fsdp_group
-           None,  # debug
-           None,  # module
-           None,  # skip_fp8_weight_update
-           None,  # symmetric_ar_type
+           None,
        )
@@ -1523,8 +1489,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
            first microbatch (since it is the first gradient being
            produced)
        """
+       is_grad_enabled = torch.is_grad_enabled()
+
        if is_in_onnx_export_mode():
-           return self.onnx_forward(inp, fp8_output)
+           return self.onnx_forward(inp, fp8_output, is_grad_enabled)

        debug = self.is_debug_iter()
@@ -1546,9 +1514,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        ).is_fp8_ubuf():
            fp8_grad = True

-       with torch.cuda.device(
-           getattr(self, list(self.named_parameters())[0][0]).device
-       ), self.prepare_forward(
+       with self.prepare_forward(
            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
        ) as inp:
@@ -1556,14 +1522,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

            quantizers = (
-               self._get_quantizers(fp8_output, fp8_grad)
+               self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
                if not debug
-               else self._get_debug_quantizers(fp8_output, fp8_grad)
+               else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
            )
            if debug:
                if self.no_debug_features_active(quantizers):
                    debug = False
-                   quantizers = self._get_quantizers(fp8_output, fp8_grad)
+                   quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
            (
                input_quantizer,
@@ -1574,18 +1540,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
                grad_output_quantizer,
            ) = quantizers

-           if torch.is_grad_enabled():
+           if is_grad_enabled:
                fwd_fn = _LayerNormLinear.apply
-               args = []
+               autograd_ctx = []
            else:
                fwd_fn = _LayerNormLinear.forward
-               args = [None]
-           args += (
-               inp,
-               self.layer_norm_weight,
-               self.layer_norm_bias,
-               weight_tensor,
-               bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
+               autograd_ctx = [None]
+           non_tensor_args = (
                self.eps,
                is_first_microbatch,
                self.fp8,
@@ -1607,8 +1568,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
                self.parallel_mode,
                self.return_layernorm_output,
                self.return_layernorm_output_gathered,
-               torch.is_grad_enabled(),
-               self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
+               is_grad_enabled,
+               self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
                self.bwd_ln_sm_margin,
                self.zero_centered_gamma,
                self.normalization,
@@ -1625,7 +1586,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
                self.symmetric_ar_type,
                debug,
            )
-           out = fwd_fn(*args)
+           out = fwd_fn(
+               *autograd_ctx,
+               inp,
+               self.layer_norm_weight,
+               self.layer_norm_bias,
+               weight_tensor,
+               bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
+               non_tensor_args,
+           )

        if self.return_layernorm_output:
            out, ln_out = out
@@ -1641,7 +1610,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
            return out, ln_out
        return out

-   def _get_quantizers(self, fp8_output, fp8_grad):
+   def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
        if not self.fp8:
            return [None] * 6
        grad_input_quantizer = None
@@ -1653,7 +1622,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        (weight_quantizer,) = self._get_weight_quantizers()
        if fp8_output:
            output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
-       if torch.is_grad_enabled():
+       if is_grad_enabled:
            grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
            grad_output_quantizer.internal = True
            if fp8_grad:
@@ -1668,8 +1637,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
            grad_output_quantizer,
        )

-   def _get_debug_quantizers(self, fp8_output, fp8_grad):
-       original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
+   def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
+       original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
        assert TEDebugState.debug_enabled
        from ...debug.pytorch.debug_quantization import DebugQuantizer
@@ -1694,6 +1663,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        self,
        inp: torch.Tensor,
        fp8_output: bool,
+       is_grad_enabled: bool,
    ) -> torch.Tensor:
        """
        ONNX-compatible version of the forward function that provides numerical equivalence
...@@ -1709,7 +1679,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1709,7 +1679,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
*_, *_,
) = self._get_quantizers(fp8_output, fp8_grad=False) ) = self._get_quantizers(fp8_output, False, is_grad_enabled)
inp_dtype = inp.dtype inp_dtype = inp.dtype
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
......
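Besides threading `is_grad_enabled` through explicitly, the hunks above drop the `with torch.cuda.device(getattr(self, list(self.named_parameters())[0][0]).device)` wrapper around `prepare_forward`. Materializing the full `named_parameters()` list on every forward just to read one parameter's device is pure host-side work; the commit simply removes the wrapper, presumably because `prepare_forward` already runs on the correct device. The sketch below only illustrates why the removed code was expensive and shows a cheaper alternative (a lazily cached device on a hypothetical module); it is not what this commit itself does.

```python
import contextlib

import torch


class _CachedDeviceLinear(torch.nn.Module):
    """Hypothetical module: resolve the parameter device once and reuse it,
    instead of rebuilding list(self.named_parameters()) on every forward."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self._param_device = None  # filled lazily on first forward

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        if self._param_device is None:
            self._param_device = self.weight.device  # one-time lookup
        # Only enter a device guard when actually running on CUDA.
        guard = (
            torch.cuda.device(self._param_device)
            if self._param_device.type == "cuda"
            else contextlib.nullcontext()
        )
        with guard:
            return inp @ self.weight.t()
```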
...@@ -20,7 +20,6 @@ from transformer_engine.pytorch import torch_version ...@@ -20,7 +20,6 @@ from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_custom from transformer_engine.pytorch.tensor.utils import is_custom
from .base import ( from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_workspace,
_ub_communicators, _ub_communicators,
get_ub, get_ub,
TransformerEngineBaseModule, TransformerEngineBaseModule,
...@@ -45,6 +44,7 @@ from ..utils import ( ...@@ -45,6 +44,7 @@ from ..utils import (
clear_tensor_data, clear_tensor_data,
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
get_nvtx_range_context,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -174,55 +174,61 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -174,55 +174,61 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias: torch.Tensor, fc1_bias: torch.Tensor,
fc2_weight: torch.Tensor, fc2_weight: torch.Tensor,
fc2_bias: torch.Tensor, fc2_bias: torch.Tensor,
eps: float, non_tensor_args: Tuple,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer],
fc1_output_quantizer: Optional[Quantizer],
fc1_grad_input_quantizer: Optional[Quantizer],
fc1_grad_weight_quantizer: Optional[Quantizer],
fc1_grad_output_quantizer: Optional[Quantizer],
fc2_input_quantizer: Optional[Quantizer],
fc2_weight_quantizer: Optional[Quantizer],
fc2_output_quantizer: Optional[Quantizer],
fc2_grad_input_quantizer: Optional[Quantizer],
fc2_grad_weight_quantizer: Optional[Quantizer],
fc2_grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
bias_gelu_fusion: bool,
set_parallel_mode: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
activation_params: Optional[dict],
normalization: str,
ub_overlap_ag: bool,
ub_overlap_rs: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
gemm_gelu_fusion: bool,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
return_layernorm_output,
return_layernorm_output_gathered,
bias_gelu_fusion,
set_parallel_mode,
is_grad_enabled,
fwd_ln_sm_margin,
bwd_ln_sm_margin,
zero_centered_gamma,
activation,
activation_params,
normalization,
ub_overlap_ag,
ub_overlap_rs,
ub_overlap_rs_dgrad,
ub_bulk_wgrad,
ub_bulk_dgrad,
gemm_gelu_fusion,
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
debug,
) = non_tensor_args
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features, inp_shape = ln_weight.numel(), inp.shape in_features, inp_shape = ln_weight.numel(), inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible" assert inp_shape[-1] == in_features, "GEMM not possible"
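The comment in this hunk states the core idea of the commit: `torch.autograd.Function.apply` inspects every positional argument on each call, so collapsing the dozens of scalar and config fields into a single `non_tensor_args` tuple (while tensors stay as individual arguments so autograd can track them) cuts per-call CPU time, and the backward pass only has to return one trailing `None` for the packed tuple. Below is a minimal sketch of that pattern on a hypothetical op; the names `_ScaledMul`, `scale`, and `cfg` are illustrative and not part of Transformer Engine.

```python
import torch


class _ScaledMul(torch.autograd.Function):
    """Sketch of the argument-packing pattern: tensors are passed individually,
    all non-tensor configuration travels in one tuple."""

    @staticmethod
    def forward(ctx, inp, weight, non_tensor_args):
        # Unpack the packed configuration once inside the op.
        scale, is_grad_enabled = non_tensor_args
        if is_grad_enabled:
            ctx.save_for_backward(inp, weight)
            ctx.scale = scale
        return inp * weight * scale

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        # One trailing None covers the packed tuple, replacing the long
        # per-argument None list of the old signature.
        return grad_out * weight * ctx.scale, grad_out * inp * ctx.scale, None


# Usage mirrors the module-level forward: query grad mode once, then either go
# through apply() (autograd tracking) or call forward() directly for inference.
x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
is_grad_enabled = torch.is_grad_enabled()
cfg = (2.0, is_grad_enabled)
out = (
    _ScaledMul.apply(x, w, cfg)
    if is_grad_enabled
    else _ScaledMul.forward(None, x, w, cfg)
)
out.sum().backward()
```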
...@@ -433,7 +439,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -433,7 +439,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_outputs = general_gemm( fc1_outputs = general_gemm(
fc1_weight_final, fc1_weight_final,
ln_out_total, ln_out_total,
get_workspace(),
quantization_params=( quantization_params=(
fc2_input_quantizer fc2_input_quantizer
if gemm_gelu_fusion if gemm_gelu_fusion
...@@ -517,7 +522,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -517,7 +522,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
fc2_weight_final, fc2_weight_final,
act_out, act_out,
get_workspace(),
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=fc2_bias, bias=fc2_bias,
quantization_params=fc2_output_quantizer, quantization_params=fc2_output_quantizer,
...@@ -704,7 +708,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -704,7 +708,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_LayerNormMLP_backward"): with get_nvtx_range_context("_LayerNormMLP_backward"):
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
inputmat, inputmat,
...@@ -874,7 +878,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -874,7 +878,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_output, *_ = general_gemm( gemm_output, *_ = general_gemm(
fc2_weight, fc2_weight,
grad_output, grad_output,
get_workspace(),
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=( quantization_params=(
...@@ -968,7 +971,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -968,7 +971,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs = { fc2_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
origin_fc2_weight.main_grad.dtype origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
...@@ -1138,7 +1140,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1138,7 +1140,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
fc1_weight, fc1_weight,
dact, dact,
get_workspace(),
out=gemm_out, out=gemm_out,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer, quantization_params=ctx.fc1_grad_input_quantizer,
...@@ -1217,7 +1218,6 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1217,7 +1218,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
fc1_wgrad_gemm_kwargs = { fc1_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
origin_fc1_weight.main_grad.dtype origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
...@@ -1399,52 +1399,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1399,52 +1399,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias_grad if fc1_bias is not None else None, fc1_bias_grad if fc1_bias is not None else None,
fc2_wgrad, # pylint: disable=possibly-used-before-assignment fc2_wgrad, # pylint: disable=possibly-used-before-assignment
fc2_bias_grad, fc2_bias_grad,
None, # eps None,
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer,
None, # fc1_weight_quantizer,
None, # fc1_output_quantizer,
None, # fc1_grad_input_quantizer,
None, # fc1_grad_weight_quantizer,
None, # fc1_grad_output_quantizer,
None, # fc2_input_quantizer,
None, # fc2_weight_quantizer,
None, # fc2_output_quantizer,
None, # fc2_grad_input_quantizer,
None, # fc2_grad_weight_quantizer,
None, # fc2_grad_output_quantizer,
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # return_layernorm_output
None, # return_layernorm_output_gathered
None, # bias_gelu_fusion
None, # set_parallel_mode
None, # is_grad_enabled
None, # fwd_ln_sm_margin
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # activation
None, # activation_params
None, # normalization
None, # ub_overlap_ag
None, # ub_overlap_rs
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # gemm_gelu_fusion
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # debug
) )
...@@ -1827,8 +1782,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1827,8 +1782,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp) return self.onnx_forward(inp, is_grad_enabled)
debug = self.is_debug_iter() debug = self.is_debug_iter()
...@@ -1844,19 +1801,17 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1844,19 +1801,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True fp8_output = True
with torch.cuda.device( with self.prepare_forward(inp, num_gemms=2) as inp:
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = ( quantizers = (
self._get_quantizers(fp8_output) self._get_quantizers(fp8_output, is_grad_enabled)
if not debug if not debug
else self._get_debug_quantizers(fp8_output) else self._get_debug_quantizers(fp8_output, is_grad_enabled)
) )
if debug: if debug:
if self.no_debug_features_active(quantizers): if self.no_debug_features_active(quantizers):
debug = False debug = False
quantizers = self._get_quantizers(fp8_output) quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
# Get quantizers # Get quantizers
( (
...@@ -1888,20 +1843,14 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1888,20 +1843,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False self.bias_gelu_nvfusion = False
if torch.is_grad_enabled(): if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply fwd_fn = _LayerNormMLP.apply
args = [] autograd_ctx = []
else: else:
fwd_fn = _LayerNormMLP.forward fwd_fn = _LayerNormMLP.forward
args = [None] autograd_ctx = [None]
args += (
inp, non_tensor_args = (
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
...@@ -1930,8 +1879,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1930,8 +1879,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_layernorm_output_gathered, self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug, self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode, self.set_parallel_mode,
torch.is_grad_enabled(), is_grad_enabled,
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.activation, self.activation,
...@@ -1949,7 +1898,17 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1949,7 +1898,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.symmetric_ar_type, self.symmetric_ar_type,
debug, debug,
) )
out = fwd_fn(*args) out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
if self.return_layernorm_output: if self.return_layernorm_output:
out, ln_out = out out, ln_out = out
...@@ -1965,7 +1924,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1965,7 +1924,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return out, ln_out return out, ln_out
return out return out
def _get_quantizers(self, fp8_output): def _get_quantizers(self, fp8_output, is_grad_enabled):
( (
fc1_input_quantizer, fc1_input_quantizer,
fc1_output_quantizer, fc1_output_quantizer,
...@@ -1995,7 +1954,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1995,7 +1954,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_output_quantizer = self.quantizers["scaling_fwd"][ fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT tex.FP8FwdTensors.GEMM2_OUTPUT
] ]
if torch.is_grad_enabled(): if is_grad_enabled:
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2 tex.FP8BwdTensors.GRAD_OUTPUT2
] ]
...@@ -2020,7 +1979,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2020,7 +1979,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer, fc2_grad_output_quantizer,
) )
def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: def onnx_forward(
self, inp: torch.Tensor, is_grad_enabled: bool
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
""" """
ONNX-compatible version of the forward function that provides numerical equivalence ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations. while only using operations that have defined ONNX symbolic translations.
...@@ -2037,7 +1998,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2037,7 +1998,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer, fc2_weight_quantizer,
output_quantizer, output_quantizer,
*_, *_,
) = self._get_quantizers(False) ) = self._get_quantizers(False, is_grad_enabled)
inp_dtype = inp.dtype inp_dtype = inp.dtype
fc1_weight, fc2_weight = self._get_weight_tensors() fc1_weight, fc2_weight = self._get_weight_tensors()
...@@ -2122,10 +2083,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2122,10 +2083,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
return fc2_out, fc2_bias.to(inp_dtype) return fc2_out, fc2_bias.to(inp_dtype)
return fc2_out return fc2_out
def _get_debug_quantizers(self, fp8_output): def _get_debug_quantizers(self, fp8_output, is_grad_enabled):
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
base_quantizers = list(self._get_quantizers(fp8_output)) base_quantizers = list(self._get_quantizers(fp8_output, is_grad_enabled))
assert TEDebugState.debug_enabled assert TEDebugState.debug_enabled
def make_debug(prefix, offset): def make_debug(prefix, offset):
...@@ -2268,7 +2229,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2268,7 +2229,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
""" """
if not self.need_backward_dw(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): with get_nvtx_range_context("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop() (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
if self.use_bias and self.fc1_bias.grad is None: if self.use_bias and self.fc1_bias.grad is None:
(fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop() (fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
......
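The backward pass and the wgrad helper in this file switch from `torch.cuda.nvtx.range` to `get_nvtx_range_context` imported from `..utils`. The helper's body is not shown in this diff; a plausible sketch of such a helper (assumed behavior only, and the `NVTE_ENABLE_NVTX` gate below is hypothetical) returns a no-op context manager unless profiling is explicitly enabled, so steady-state iterations skip the NVTX push/pop calls entirely.

```python
import contextlib
import os

import torch

# Assumed behavior: the commit does not show get_nvtx_range_context's body,
# and the NVTE_ENABLE_NVTX environment variable here is hypothetical.
_NVTX_ENABLED = os.getenv("NVTE_ENABLE_NVTX", "0") == "1"


def nvtx_range_context(msg: str):
    """Return a real NVTX range when profiling is on, else a free no-op."""
    if _NVTX_ENABLED:
        return torch.cuda.nvtx.range(msg)
    return contextlib.nullcontext()


# Usage, matching the call sites in the diff:
with nvtx_range_context("_LayerNormMLP_backward"):
    pass  # backward computation would run here
```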
...@@ -19,7 +19,6 @@ from .base import ( ...@@ -19,7 +19,6 @@ from .base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad, get_dummy_wgrad,
get_ub, get_ub,
get_workspace,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
...@@ -38,6 +37,7 @@ from ..utils import ( ...@@ -38,6 +37,7 @@ from ..utils import (
assert_dim_for_all_gather, assert_dim_for_all_gather,
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
get_nvtx_range_context,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -90,42 +90,46 @@ class _Linear(torch.autograd.Function): ...@@ -90,42 +90,46 @@ class _Linear(torch.autograd.Function):
weight: torch.Tensor, weight: torch.Tensor,
inp: torch.Tensor, inp: torch.Tensor,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
is_first_microbatch: Union[bool, None], non_tensor_args: Tuple,
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
ub_overlap_rs_fprop: bool,
ub_overlap_ag_dgrad: bool,
ub_overlap_ag_fprop: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_dgrad: bool,
ub_bulk_wgrad: bool,
ub_name: str,
fp8_output: bool, # pylint: disable=unused-argument
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
save_original_input: bool = False,
debug: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
(
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
fuse_wgrad_accumulation,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
parallel_mode,
is_grad_enabled,
ub_overlap_rs_fprop,
ub_overlap_ag_dgrad,
ub_overlap_ag_fprop,
ub_overlap_rs_dgrad,
ub_bulk_dgrad,
ub_bulk_wgrad,
ub_name,
fp8_output, # pylint: disable=unused-variable
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
save_original_input,
debug,
) = non_tensor_args
# NVTX label for profiling # NVTX label for profiling
nvtx_label = "transformer_engine._Linear.forward" nvtx_label = "transformer_engine._Linear.forward"
if ub_name is not None: if ub_name is not None:
...@@ -320,7 +324,6 @@ class _Linear(torch.autograd.Function): ...@@ -320,7 +324,6 @@ class _Linear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat, weightmat,
inputmat_total, inputmat_total,
get_workspace(),
quantization_params=output_quantizer, quantization_params=output_quantizer,
out_dtype=activation_dtype, out_dtype=activation_dtype,
bias=bias, bias=bias,
...@@ -497,7 +500,7 @@ class _Linear(torch.autograd.Function): ...@@ -497,7 +500,7 @@ class _Linear(torch.autograd.Function):
if ctx.ub_name is not None: if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}" nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_Linear_backward"): with get_nvtx_range_context("_Linear_backward"):
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors) restore_from_saved(ctx.tensor_objects, saved_tensors)
...@@ -719,7 +722,6 @@ class _Linear(torch.autograd.Function): ...@@ -719,7 +722,6 @@ class _Linear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm( gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8, weight_fp8,
grad_output, grad_output,
get_workspace(),
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=ctx.grad_input_quantizer, quantization_params=ctx.grad_input_quantizer,
...@@ -845,7 +847,6 @@ class _Linear(torch.autograd.Function): ...@@ -845,7 +847,6 @@ class _Linear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure # Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = { wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": ( "out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
), ),
...@@ -977,39 +978,7 @@ class _Linear(torch.autograd.Function): ...@@ -977,39 +978,7 @@ class _Linear(torch.autograd.Function):
wgrad, wgrad,
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias, grad_bias,
None, # is_first_microbatch None,
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
None, # ub_overlap_rs_fprop
None, # ub_overlap_ag_dgrad
None, # ub_overlap_ag_fprop
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fp8_output
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # save_original_input
None, # debug
) )
...@@ -1403,8 +1372,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -1403,8 +1372,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being first microbatch (since it is the first gradient being
produced) produced)
""" """
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output) return self.onnx_forward(inp, fp8_output, is_grad_enabled)
debug = self.is_debug_iter() debug = self.is_debug_iter()
...@@ -1426,9 +1397,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1426,9 +1397,7 @@ class Linear(TransformerEngineBaseModule):
).is_fp8_ubuf(): ).is_fp8_ubuf():
fp8_grad = True fp8_grad = True
with torch.cuda.device( with self.prepare_forward(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp, inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor), allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp: ) as inp:
...@@ -1436,14 +1405,14 @@ class Linear(TransformerEngineBaseModule): ...@@ -1436,14 +1405,14 @@ class Linear(TransformerEngineBaseModule):
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = ( quantizers = (
self._get_quantizers(fp8_output, fp8_grad) self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad) else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
) )
if debug: if debug:
if self.no_debug_features_active(quantizers): if self.no_debug_features_active(quantizers):
debug = False debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad) quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
( (
input_quantizer, input_quantizer,
...@@ -1454,16 +1423,14 @@ class Linear(TransformerEngineBaseModule): ...@@ -1454,16 +1423,14 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) = quantizers ) = quantizers
if torch.is_grad_enabled(): if is_grad_enabled:
linear_fn = _Linear.apply linear_fn = _Linear.apply
args = [] autograd_ctx = []
else: else:
linear_fn = _Linear.forward linear_fn = _Linear.forward
args = [None] autograd_ctx = [None]
args += (
weight_tensor, non_tensor_args = (
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
...@@ -1482,7 +1449,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1482,7 +1449,7 @@ class Linear(TransformerEngineBaseModule):
self.tp_size > 1, self.tp_size > 1,
self.activation_dtype, self.activation_dtype,
self.parallel_mode, self.parallel_mode,
torch.is_grad_enabled(), is_grad_enabled,
self.ub_overlap_rs_fprop, self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad, self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop, self.ub_overlap_ag_fprop,
...@@ -1498,7 +1465,13 @@ class Linear(TransformerEngineBaseModule): ...@@ -1498,7 +1465,13 @@ class Linear(TransformerEngineBaseModule):
self.save_original_input, self.save_original_input,
debug, debug,
) )
out = linear_fn(*args) out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype) out = out + cast_if_needed(bias_tensor, self.activation_dtype)
...@@ -1506,7 +1479,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1506,7 +1479,7 @@ class Linear(TransformerEngineBaseModule):
return out, cast_if_needed(bias_tensor, self.activation_dtype) return out, cast_if_needed(bias_tensor, self.activation_dtype)
return out return out
def _get_quantizers(self, fp8_output, fp8_grad): def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
if not self.fp8: if not self.fp8:
return [None] * 6 return [None] * 6
grad_input_quantizer = None grad_input_quantizer = None
...@@ -1518,7 +1491,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1518,7 +1491,7 @@ class Linear(TransformerEngineBaseModule):
(weight_quantizer,) = self._get_weight_quantizers() (weight_quantizer,) = self._get_weight_quantizers()
if fp8_output: if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled(): if is_grad_enabled:
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True grad_output_quantizer.internal = True
if fp8_grad: if fp8_grad:
...@@ -1532,8 +1505,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -1532,8 +1505,8 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer, grad_output_quantizer,
) )
def _get_debug_quantizers(self, fp8_output, fp8_grad): def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad) original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
assert TEDebugState.debug_enabled assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_quantization import DebugQuantizer
...@@ -1588,6 +1561,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1588,6 +1561,7 @@ class Linear(TransformerEngineBaseModule):
self, self,
inp: torch.Tensor, inp: torch.Tensor,
fp8_output: bool, fp8_output: bool,
is_grad_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
ONNX-compatible version of the forward function that provides numerical equivalence ONNX-compatible version of the forward function that provides numerical equivalence
...@@ -1604,7 +1578,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1604,7 +1578,7 @@ class Linear(TransformerEngineBaseModule):
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
*_, *_,
) = self._get_quantizers(fp8_output, False) ) = self._get_quantizers(fp8_output, False, is_grad_enabled)
inp_dtype = inp.dtype inp_dtype = inp.dtype
if input_quantizer is not None: if input_quantizer is not None:
......
...@@ -25,7 +25,6 @@ from ...module.base import ( ...@@ -25,7 +25,6 @@ from ...module.base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
get_dummy_wgrad, get_dummy_wgrad,
get_workspace,
) )
from ...tensor import Quantizer from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer from ...tensor.float8_tensor import Float8Quantizer
...@@ -585,7 +584,6 @@ class BasicLinear(BasicOperation): ...@@ -585,7 +584,6 @@ class BasicLinear(BasicOperation):
y, *_ = general_gemm( y, *_ = general_gemm(
w, w,
x, x,
get_workspace(),
out_dtype=dtype, out_dtype=dtype,
quantization_params=output_quantizer, quantization_params=output_quantizer,
alpha=alpha, alpha=alpha,
...@@ -875,7 +873,6 @@ class BasicLinear(BasicOperation): ...@@ -875,7 +873,6 @@ class BasicLinear(BasicOperation):
dx, *_ = general_gemm( dx, *_ = general_gemm(
w, w,
dy, dy,
get_workspace(),
out_dtype=dtype, out_dtype=dtype,
quantization_params=grad_input_quantizer, quantization_params=grad_input_quantizer,
alpha=grad_input_alpha, alpha=grad_input_alpha,
...@@ -928,7 +925,6 @@ class BasicLinear(BasicOperation): ...@@ -928,7 +925,6 @@ class BasicLinear(BasicOperation):
dw, *_ = general_gemm( dw, *_ = general_gemm(
x, x,
dy, dy,
get_workspace(),
out_dtype=dw_dtype, out_dtype=dw_dtype,
alpha=grad_weight_alpha, alpha=grad_weight_alpha,
beta=grad_weight_beta, beta=grad_weight_beta,
......
...@@ -19,7 +19,6 @@ from ...module.base import ( ...@@ -19,7 +19,6 @@ from ...module.base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad, get_dummy_wgrad,
get_ub, get_ub,
get_workspace,
) )
from ...quantized_tensor import Quantizer from ...quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer
...@@ -378,7 +377,6 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -378,7 +377,6 @@ class UserbuffersBackwardLinear(FusedOperation):
dx, *_ = general_gemm( dx, *_ = general_gemm(
w, w,
dy, dy,
get_workspace(),
out_dtype=dtype, out_dtype=dtype,
quantization_params=grad_input_quantizer, quantization_params=grad_input_quantizer,
layout="NN", layout="NN",
...@@ -464,7 +462,6 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -464,7 +462,6 @@ class UserbuffersBackwardLinear(FusedOperation):
dw, *_ = general_gemm( dw, *_ = general_gemm(
x, x,
dy, dy,
get_workspace(),
out_dtype=dw_dtype, out_dtype=dw_dtype,
accumulate=accumulate_into_grad_weight, accumulate=accumulate_into_grad_weight,
layout="NT", layout="NT",
......
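The final hunks remove `get_workspace()` from every `general_gemm` call site (and the earlier hunks removed the per-module `workspace` tensor), which implies the cuBLASLt workspace is now obtained inside the GEMM extension rather than fetched and passed from Python on each call. Below is a rough sketch of the kind of per-device caching such a lookup amounts to; the helper name and the 32 MiB size are assumptions for illustration, not the extension's actual values.

```python
import functools

import torch

# Assumed size for illustration; the real cuBLASLt workspace size is an
# implementation detail of the extension and not shown in this diff.
_WORKSPACE_SIZE_BYTES = 32 * 1024 * 1024


@functools.lru_cache(maxsize=None)
def _gemm_workspace(device_index: int) -> torch.Tensor:
    """Allocate one reusable workspace per CUDA device on first use,
    instead of threading a buffer through every GEMM call."""
    return torch.empty(
        _WORKSPACE_SIZE_BYTES, dtype=torch.uint8, device=f"cuda:{device_index}"
    )


# Hypothetical usage: resolve the workspace lazily inside a GEMM wrapper.
def gemm_with_cached_workspace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    _ = _gemm_workspace(a.device.index)  # would be handed to cuBLASLt internally
    return a @ b
```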