"examples/vscode:/vscode.git/clone" did not exist on "27612051d0ef050e2659a4b88891e98b77cf45ef"
Unverified commit e1edaaec authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Reduce CPU overheads (#2377)



Initial changes to reduce PyTorch CPU overheads
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 42d22740
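
The central CPU-overhead reduction in the Python modules is to stop passing dozens of positional arguments through `torch.autograd.Function.apply`, which inspects every argument on each call. The non-tensor arguments are instead packed into a single `non_tensor_args` tuple that is unpacked inside `forward`, and `backward` returns a single `None` for that slot. A minimal sketch of the pattern with toy names (not the actual TE modules):

```python
import torch
from typing import Tuple

class _ToyScale(torch.autograd.Function):
    """Toy illustration of the non_tensor_args packing used in this commit."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, non_tensor_args: Tuple) -> torch.Tensor:
        # Unpack inside forward so autograd only sees two arguments,
        # keeping PyTorch's per-argument checking cheap.
        (scale, _is_grad_enabled) = non_tensor_args  # second field unused here
        ctx.scale = scale
        return inp * scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient slot per forward argument: the packed tuple gets a
        # single None instead of one None per field.
        return grad_output * ctx.scale, None

x = torch.randn(4, requires_grad=True)
y = _ToyScale.apply(x, (2.0, torch.is_grad_enabled()))
y.sum().backward()
```
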
......@@ -2489,7 +2489,6 @@ class _custom_mha_fp8(torch.autograd.Function):
max_s: int,
fast_zero_fill: bool,
fp8_meta: Dict[str, Any],
workspace: torch.Tensor,
is_training: bool,
mask_type: str,
quantizers: list[Quantizer],
......@@ -2518,7 +2517,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv, *_ = ext.general_gemm(
qkv_weight_fp8,
inp_fp8,
workspace,
bias=qkv_bias,
out_dtype=qkv_weight_fp8.dtype,
quantization_params=qkv_quantizer,
......@@ -2560,9 +2558,7 @@ class _custom_mha_fp8(torch.autograd.Function):
s_quantizer=s_quantizer,
)
tensors_to_save, tensor_objects = prepare_for_saving(
q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
)
tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
......@@ -2592,7 +2588,7 @@ class _custom_mha_fp8(torch.autograd.Function):
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
saved_tensors = ctx.saved_tensors
(q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
ctx.tensor_objects, saved_tensors
)
......@@ -2648,7 +2644,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_dgrad, *_ = ext.general_gemm(
qkv_weight_fp8,
dqkv_c,
workspace,
ctx.dtype,
use_split_accumulator=_2X_ACC_DGRAD,
layout="NN",
......@@ -2658,7 +2653,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_wgrad, *_ = ext.general_gemm(
inp_fp8,
dqkv,
workspace,
ctx.dtype,
use_split_accumulator=_2X_ACC_WGRAD,
layout="NT",
......@@ -2709,9 +2703,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
with torch.no_grad():
self.qkv_bias.zero_()
self.qkv_weight.fill_(1.0)
self.workspace = torch.empty(
_CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
)
def forward(
self,
......@@ -2730,7 +2721,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
max_s,
self.fast_zero_fill,
self.fp8_meta,
self.workspace,
self.training,
self.mask_type,
self.quantizers,
......
......@@ -82,7 +82,6 @@ def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split
out, *_ = tepytorch.cpp_extensions.general_gemm(
fp8_tensor1,
fp8_tensor2,
tepytorch.module.base.get_workspace(),
torch.float32,
use_split_accumulator=use_split_accumulator,
)
......@@ -199,7 +198,6 @@ def _emulate_linear(
wgrad, *_ = tepytorch.cpp_extensions.general_gemm(
wgrad_input,
wgrad_gradient,
tepytorch.module.base.get_workspace(),
torch.float32,
layout="NT",
grad=True,
......
......@@ -7,7 +7,7 @@ import sys
import pytest
import torch
import transformer_engine
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear
from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear, GroupedLinear
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
......@@ -19,7 +19,9 @@ model_configs = {
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
@pytest.mark.parametrize(
"module", ["TransformerLayer", "DotProductAttention", "Linear", "GroupedLinear"]
)
def test_current_device(model, module):
"""Test cases where current device is different from tensor device"""
......@@ -58,7 +60,7 @@ def test_current_device(model, module):
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
kwargs["max_seqlen_q"] = config.max_seqlen_q
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
if module == "DotProductAttention":
elif module == "DotProductAttention":
model = DotProductAttention(
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
)
......@@ -97,6 +99,24 @@ def test_current_device(model, module):
requires_grad=True,
)
]
elif module == "GroupedLinear":
num_gemms = 4
model = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
params_dtype=dtype,
device=f"cuda:{tensor_device}",
)
args = [
torch.randn(
(config.max_seqlen_q * config.batch_size * (num_gemms - 1), config.hidden_size),
dtype=dtype,
device=f"cuda:{tensor_device}",
requires_grad=True,
),
[0] + [config.max_seqlen_q * config.batch_size] * (num_gemms - 1), # Empty first split.
]
current_device_before = torch.cuda.current_device()
out = model(*args, **kwargs)
......
......@@ -44,7 +44,6 @@ from transformer_engine.pytorch import (
)
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states
......@@ -2690,7 +2689,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
general_gemm(
A[i],
B[i],
get_workspace(),
dtype,
grad=grad,
accumulate=accumulate,
......@@ -2705,7 +2703,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B,
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
grad=grad,
accumulate=accumulate,
......@@ -2760,7 +2757,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
quantized_out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=out_quantizer,
bias=None,
......@@ -2770,7 +2766,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
out, *_ = general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
quantization_params=None,
bias=None,
......@@ -2846,7 +2841,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
general_gemm(
A_fp8[i],
B_fp8[i],
get_workspace(),
dtype,
out=out_ref[i],
accumulate=accumulate,
......@@ -2856,7 +2850,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
B_fp8,
out,
dtype,
get_multi_stream_cublas_workspace(),
m_splits=m_splits,
accumulate=accumulate,
)
......
......@@ -36,7 +36,6 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from utils import ModelConfig
......@@ -912,7 +911,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
_ = general_gemm(A=weight, B=inp, workspace=get_workspace())
_ = general_gemm(A=weight, B=inp)
torch.cuda.synchronize()
......@@ -936,7 +935,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
bias=None,
use_split_accumulator=False,
......
......@@ -19,7 +19,12 @@ from transformer_engine.pytorch.utils import (
get_device_compute_capability,
split_tensor_along_dim,
)
from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop
from transformer_engine.pytorch.utils import (
attention_mask_func,
nvtx_range_push,
nvtx_range_pop,
get_nvtx_range_context,
)
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
......@@ -1445,7 +1450,7 @@ class FusedAttnFunc(torch.autograd.Function):
dk = dk[..., : d_out.shape[-1]]
dv = dv[..., : d_out.shape[-1]]
else:
with torch.cuda.nvtx.range("FusedAttnFunc.backward"):
with get_nvtx_range_context("FusedAttnFunc.backward"):
# get nominal data type of dq, dk, dv
# FP16/BF16 attention: torch.float16 or torch.bfloat16
# FP8 attention: torch.float16 or torch.bfloat16
......
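
`get_nvtx_range_context` comes from `transformer_engine.pytorch.utils`; its implementation is not part of this diff. A plausible minimal sketch of such a helper (an assumption, not the actual TE code) is to gate NVTX ranges behind an opt-in so the default hot path pays no push/pop cost:

```python
import contextlib
import os
import torch

# Hypothetical sketch: the env-var name and gating logic are assumptions.
_NVTX_ENABLED = os.getenv("NVTE_ENABLE_NVTX", "0") == "1"

def get_nvtx_range_context(msg: str):
    """Return a real NVTX range only when profiling is opted in."""
    if _NVTX_ENABLED:
        return torch.cuda.nvtx.range(msg)
    return contextlib.nullcontext()
```
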
......@@ -975,7 +975,7 @@ class DotProductAttention(TransformerEngineBaseModule):
Whether to enforce output to be in FP8 or not.
"""
with torch.cuda.device(query_layer.device), self.prepare_forward(
with self.prepare_forward(
query_layer,
num_gemms=3,
allow_non_contiguous=True,
......
......@@ -6,23 +6,59 @@
from typing import Iterable, Optional, Tuple, Union, List
import os
import functools
import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor
from ..quantized_tensor import Quantizer
from ..quantized_tensor import Quantizer, QuantizedTensor, QuantizedTensorStorage
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_custom
from ..custom_recipes.gemm import custom_gemm
from ...debug.pytorch.debug_quantization import DebugQuantizer
__all__ = [
"general_gemm",
"general_grouped_gemm",
]
_NUM_MAX_UB_STREAMS = 3
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
# 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
return 32 * 1024 * 1024 + 1024
return 4_194_304
@functools.lru_cache(maxsize=None)
def get_cublas_workspace(device: int, ub: bool, grouped_gemm: bool) -> torch.Tensor:
"""Returns workspace for cublas GEMM."""
assert not (ub and grouped_gemm), "UB is unsupported for grouped GEMM."
if ub:
return torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device
).repeat(_NUM_MAX_UB_STREAMS)
if grouped_gemm:
_multi_stream_cublas_workspace = []
for _ in range(tex.get_num_cublas_streams()):
_multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
)
return _multi_stream_cublas_workspace
return torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device=device)
def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
"""Validate whether a GEMM scaling factor is consistent with its usage"""
if required:
......@@ -32,10 +68,35 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
return 0.0
def get_tensor_device(tensor: torch.Tensor) -> int:
"""Returns tensor device as an integer"""
if not isinstance(tensor, QuantizedTensorStorage):
return tensor.device.index
if isinstance(tensor, QuantizedTensor):
return tensor.device.index
if isinstance(tensor, (Float8BlockwiseQTensorStorage, MXFP8TensorStorage, NVFP4TensorStorage)):
return (
tensor._rowwise_data.device.index
if tensor._rowwise_data is not None
else tensor._columnwise_data.device.index
)
if isinstance(tensor, Float8TensorStorage):
return (
tensor._data.device.index
if tensor._data is not None
else tensor._transpose.device.index
)
try:
return (
tensor._data.device.index if tensor._data is not None else tensor._data_t.device.index
)
except AttributeError:
return torch.cuda.current_device()
def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
workspace: torch.Tensor,
out_dtype: Optional[torch.dtype] = None,
quantization_params: Optional[Quantizer] = None,
gelu: bool = False,
......@@ -62,6 +123,7 @@ def general_gemm(
alpha = validate_gemm_scale(alpha, True)
beta = validate_gemm_scale(beta, accumulate)
workspace = get_cublas_workspace(get_tensor_device(A), ub is not None, False)
if ub_type is not None:
assert ub is not None, (
......@@ -159,7 +221,6 @@ def general_grouped_gemm(
B: List[torch.Tensor],
out: List[torch.Tensor],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
layout: str = "TN",
m_splits: Optional[List[int]] = None,
gelu: bool = False,
......@@ -187,6 +248,8 @@ def general_grouped_gemm(
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
sm_count = get_sm_count()
workspaces = get_cublas_workspace(get_tensor_device(A[0]), False, True)
if grad and use_bias:
grad_bias = [
torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
......
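
With the functions above, GEMM call sites no longer allocate or thread a cuBLAS workspace through; `general_gemm` and `general_grouped_gemm` resolve a per-device workspace internally via the `functools.lru_cache`-backed `get_cublas_workspace`, keyed on the first operand's device. A rough before/after sketch of a call site (tensor names are placeholders, shapes illustrative):

```python
import torch
from transformer_engine.pytorch.cpp_extensions import general_gemm

weight = torch.randn(128, 128, device="cuda")
inp = torch.randn(128, 128, device="cuda")

# Before this commit, every call site fetched and passed a workspace tensor:
#   out, *_ = general_gemm(weight, inp, get_workspace(), torch.float32)

# After: the workspace argument is gone; repeated calls on the same device
# reuse the same cached buffer.
out, *_ = general_gemm(weight, inp, torch.float32)
```
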
......@@ -108,6 +108,11 @@ std::vector<py::object> fused_attn_fwd(
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) {
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(cu_seqlens_q.device());
auto none = py::none();
// create QKV tensor wrappers
......
......@@ -95,6 +95,11 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
bool bulk_overlap, float alpha, std::optional<float> beta) {
using namespace transformer_engine::pytorch::detail;
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(workspace.device());
// Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
......@@ -351,6 +356,11 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
bool gemm_producer, at::Tensor counter) {
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(workspace.device());
// TODO: Handle scaling modes
NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;
......@@ -400,6 +410,11 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
NVTE_ERROR("not implemented, D should be allocated for single output case.");
}
// Ensure that cublasLt handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(workspace[0].device());
void* output_data_ptr = nullptr;
if (single_output) {
output_data_ptr = (*D)[0].data_ptr();
......
......@@ -64,6 +64,11 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail;
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
// Input and param tensors
auto none = py::none();
const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
......@@ -294,6 +299,11 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
const int sm_margin, const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch::detail;
// Ensure that cuDNN handle is created on the correct device,
// overriding torch.cuda.set_device calls from user side.
// Assumes all tensors passed are on the same device.
at::cuda::CUDAGuard device_guard(input.cast<at::Tensor>().device());
// Input and param tensors
auto none = py::none();
const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none);
......
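
The `at::cuda::CUDAGuard` additions above are the C++ counterpart of the `with torch.cuda.device(...)` contexts removed from the Python modules elsewhere in this commit: the extension now pins the device from one of its argument tensors, so Python callers no longer pay for a per-call device context. A small illustration of the behavior being relied on (a sketch, not TE code):

```python
import torch

def kernel_like(x: torch.Tensor) -> torch.Tensor:
    # Pin the current device from the input, mirroring at::cuda::CUDAGuard:
    # anything allocated with a bare "cuda" inside now lands on x's device,
    # regardless of what the caller last set via torch.cuda.set_device.
    with torch.cuda.device(x.device):
        workspace = torch.empty(16, dtype=torch.uint8, device="cuda")
    assert workspace.device == x.device
    return x * 2

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(0)                       # user-side device setting
    kernel_like(torch.ones(4, device="cuda:1"))    # still allocates on cuda:1
```
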
......@@ -39,13 +39,18 @@ from ..distributed import (
_fsdp_gather_tensors,
)
from ..constants import dist_group_type
from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
from ..utils import (
is_non_tn_fp8_gemm_supported,
torch_get_autocast_gpu_dtype,
get_nvtx_range_context,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
......@@ -57,11 +62,8 @@ __all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_dummy_wgrads = {}
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []
......@@ -75,35 +77,6 @@ class UserBufferQuantizationMode(Enum):
FP8 = "fp8"
def get_cublas_workspace_size_bytes() -> None:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
# 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales
return 32 * 1024 * 1024 + 1024
return 4_194_304
def get_workspace() -> torch.Tensor:
"""Returns workspace for cublas."""
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
)
return _cublas_workspace
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_workspace
if not _multi_stream_cublas_workspace:
for _ in range(tex.get_num_cublas_streams()):
_multi_stream_cublas_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
)
return _multi_stream_cublas_workspace
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
"""Returns a dummy tensor of given shape."""
assert len(shape) == 2
......@@ -276,16 +249,6 @@ def initialize_ub(
flush=True,
)
# Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
# This ensures we don't do `.repeat()` on an already expanded workspace
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
).repeat(_NUM_MAX_UB_STREAMS)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
layers_all_gather_overlap = [
"qkv_fprop",
......@@ -1078,8 +1041,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""
self.allow_different_data_and_param_types = allow_different_data_and_param_types
self.forwarded_at_least_once = True
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."
......@@ -1091,25 +1056,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()
if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe:
if self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)
if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)
if not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)
# Activation recomputation is used and this is the first forward phase.
if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
# Activation recomputation is used and this is the first forward phase.
if self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
yield inp
if self.fp8 and in_fp8_activation_recompute_phase():
if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
def set_nccl_overlap_warning_if_tp(self) -> None:
......@@ -1531,7 +1498,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""
if not self.need_backward_dw():
return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
with get_nvtx_range_context(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation:
weight_tensor = noop_cat(self._get_weight_tensors())
......@@ -1628,6 +1595,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""
if not self.fp8 and not self.fp8_calibration:
return
if not self.primary_weights_in_fp8:
return
if not hasattr(self, "weight_names") or not self.weight_names:
return
......
......@@ -24,11 +24,14 @@ class _Fp8Padding(torch.autograd.Function):
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
padded_m_splits: List[int],
is_grad_enabled: bool,
non_tensor_args: Tuple,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args
# Make sure input dimensions are compatible
in_features = inp.shape[-1]
......@@ -65,7 +68,7 @@ class _Fp8Padding(torch.autograd.Function):
grad_output.view(-1, in_features), grad_input, ctx.padded_m_splits, ctx.m_splits
)
return (grad_input, None, None, None)
return grad_input, None
class Fp8Padding(torch.nn.Module):
......@@ -128,19 +131,20 @@ class Fp8Padding(torch.nn.Module):
if m_splits == padded_m_splits:
return inp, m_splits
if torch.is_grad_enabled():
is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
fn = _Fp8Padding.apply
args = []
autograd_ctx = []
else:
fn = _Fp8Padding.forward
args = [None]
autograd_ctx = [None]
args += (
inp,
non_tensor_args = (
m_splits,
padded_m_splits,
torch.is_grad_enabled(),
is_grad_enabled,
)
out = fn(*args)
out = fn(*autograd_ctx, inp, non_tensor_args)
return out, padded_m_splits
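
Besides packing arguments, the forward path now also picks the call mechanism based on `torch.is_grad_enabled()`: `Function.apply` when a graph is needed, and a direct `Function.forward(None, ...)` call otherwise (with `None` standing in for `ctx`), skipping autograd bookkeeping entirely. The same `autograd_ctx` dispatch recurs in `GroupedLinear`, `LayerNormLinear`, `LayerNormMLP`, and `Linear` below. A compact sketch of the idiom (toy function, not TE code):

```python
import torch

class _ToyDouble(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor) -> torch.Tensor:
        return inp * 2

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output * 2

def run(inp: torch.Tensor) -> torch.Tensor:
    if torch.is_grad_enabled():
        fn, autograd_ctx = _ToyDouble.apply, []
    else:
        # Direct .forward() bypasses autograd graph construction; pass an
        # explicit None where apply() would normally supply ctx.
        fn, autograd_ctx = _ToyDouble.forward, [None]
    return fn(*autograd_ctx, inp)
```
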
......@@ -4,7 +4,7 @@
"""FP8 Padding API"""
from typing import List, Optional
from typing import List, Optional, Tuple
import torch
......@@ -24,11 +24,14 @@ class _Fp8Unpadding(torch.autograd.Function):
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
padded_m_splits: List[int],
is_grad_enabled: bool,
non_tensor_args: Tuple,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(m_splits, padded_m_splits, is_grad_enabled) = non_tensor_args
in_features = inp.shape[-1]
# Allocate cast and transpose output tensor
......@@ -63,7 +66,7 @@ class _Fp8Unpadding(torch.autograd.Function):
grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits
)
return (grad_input, None, None, None)
return grad_input, None
class Fp8Unpadding(torch.nn.Module):
......@@ -126,19 +129,20 @@ class Fp8Unpadding(torch.nn.Module):
if m_splits == padded_m_splits:
return inp
if torch.is_grad_enabled():
is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
fn = _Fp8Unpadding.apply
args = []
autograd_ctx = []
else:
fn = _Fp8Unpadding.forward
args = [None]
autograd_ctx = [None]
args += (
inp,
non_tensor_args = (
m_splits,
padded_m_splits,
torch.is_grad_enabled(),
is_grad_enabled,
)
out = fn(*args)
out = fn(*autograd_ctx, inp, non_tensor_args)
return out
......@@ -14,7 +14,6 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from .base import (
get_dummy_wgrad,
get_multi_stream_cublas_workspace,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
......@@ -28,6 +27,7 @@ from ..utils import (
clear_tensor_data,
init_method_constant,
requires_grad,
get_nvtx_range_context,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -40,7 +40,6 @@ from ..cpp_extensions import (
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
......@@ -63,28 +62,34 @@ class _GroupedLinear(torch.autograd.Function):
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizers: List[Quantizer],
weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer],
grad_output_quantizers: List[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
sequence_parallel: bool,
activation_dtype: torch.dtype,
is_grad_enabled: bool,
module,
skip_fp8_weight_update,
save_original_input,
non_tensor_args: Tuple,
*weights_and_biases,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
m_splits,
use_bias,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_output_quantizers,
fuse_wgrad_accumulation,
cpu_offloading,
sequence_parallel,
activation_dtype,
is_grad_enabled,
module,
skip_fp8_weight_update,
save_original_input,
) = non_tensor_args
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]
......@@ -183,7 +188,6 @@ class _GroupedLinear(torch.autograd.Function):
inputmats,
[out],
activation_dtype,
get_multi_stream_cublas_workspace(),
single_output=True,
m_splits=m_splits,
bias=biases,
......@@ -284,7 +288,7 @@ class _GroupedLinear(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_GroupedLinear_backward"):
with get_nvtx_range_context("_GroupedLinear_backward"):
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
N = ctx.num_gemms
inputmats = saved_tensors[:N]
......@@ -372,7 +376,6 @@ class _GroupedLinear(torch.autograd.Function):
grad_output,
[dgrad],
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
single_output=True,
layout="NN",
m_splits=ctx.m_splits,
......@@ -419,7 +422,6 @@ class _GroupedLinear(torch.autograd.Function):
grouped_gemm_wgrad = functools.partial(
general_grouped_gemm,
out_dtype=ctx.activation_dtype,
workspaces=get_multi_stream_cublas_workspace(),
layout="NT",
grad=True,
m_splits=ctx.m_splits,
......@@ -484,28 +486,11 @@ class _GroupedLinear(torch.autograd.Function):
):
grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
if ctx.reduce_and_update_bwd_fp8_tensors:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
*wgrad_list,
*grad_biases,
)
......@@ -765,16 +750,9 @@ class GroupedLinear(TransformerEngineBaseModule):
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if FP8GlobalStateManager.fp8_graph_capturing():
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
else:
skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False
is_grad_enabled = torch.is_grad_enabled()
with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
......@@ -794,7 +772,7 @@ class GroupedLinear(TransformerEngineBaseModule):
# TODO: use internal after #1638 is merged. # pylint: disable=fixme
for i in range(self.num_gemms):
input_quantizers[i].internal = False
if torch.is_grad_enabled():
if is_grad_enabled:
grad_output_quantizers = [
self.quantizers["scaling_bwd"][
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
......@@ -804,14 +782,14 @@ class GroupedLinear(TransformerEngineBaseModule):
for i in range(self.num_gemms):
grad_output_quantizers[i].internal = True
if torch.is_grad_enabled():
if is_grad_enabled:
linear_fn = _GroupedLinear.apply
args = []
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
args = [None]
args += (
inp,
autograd_ctx = [None]
non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
......@@ -826,14 +804,12 @@ class GroupedLinear(TransformerEngineBaseModule):
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
torch.is_grad_enabled(),
is_grad_enabled,
self,
skip_fp8_weight_update,
None, # skip_fp8_weight_update
self.save_original_input,
*weight_tensors,
*bias_tensors,
)
out = linear_fn(*args)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
......@@ -846,7 +822,7 @@ class GroupedLinear(TransformerEngineBaseModule):
"""
if not self.need_backward_dw():
return
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
with get_nvtx_range_context("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2]
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
......
......@@ -19,7 +19,6 @@ from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_custom
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
get_ub,
TransformerEngineBaseModule,
get_dummy_wgrad,
......@@ -40,6 +39,7 @@ from ..utils import (
nvtx_range_push,
requires_grad,
needs_quantized_gemm,
get_nvtx_range_context,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -96,47 +96,53 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
normalization: str,
ub_overlap_ag_fprop: bool,
ub_overlap_rs_fprop: bool,
ub_overlap_ag_dgrad: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_name: str,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
non_tensor_args: Tuple,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
parallel_mode,
return_layernorm_output,
return_layernorm_output_gathered,
is_grad_enabled,
fwd_ln_sm_margin,
bwd_ln_sm_margin,
zero_centered_gamma,
normalization,
ub_overlap_ag_fprop,
ub_overlap_rs_fprop,
ub_overlap_ag_dgrad,
ub_overlap_rs_dgrad,
ub_bulk_wgrad,
ub_bulk_dgrad,
ub_name,
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
debug,
) = non_tensor_args
# NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.forward"
if ub_name is not None:
......@@ -355,7 +361,6 @@ class _LayerNormLinear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat,
ln_out_total,
get_workspace(),
quantization_params=output_quantizer,
out_dtype=activation_dtype,
bias=bias,
......@@ -544,7 +549,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
with get_nvtx_range_context("_LayerNormLinear_backward"):
saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
......@@ -731,7 +736,6 @@ class _LayerNormLinear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm(
weight,
grad_output,
get_workspace(),
layout="NN",
grad=True,
quantization_params=ctx.grad_input_quantizer,
......@@ -858,7 +862,6 @@ class _LayerNormLinear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
......@@ -1026,44 +1029,7 @@ class _LayerNormLinear(torch.autograd.Function):
dbeta,
wgrad,
grad_bias,
None, # eps
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # return_layernorm_output
None, # return_layernorm_output_gathered
None, # is_grad_enabled
None, # fwd_ln_sm_margin
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # normalization
None, # ub_overlap_ag_fprop
None, # ub_overlap_rs_fprop
None, # ub_overlap_ag_dgrad
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fsdp_group
None, # debug
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None,
)
......@@ -1523,8 +1489,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
return self.onnx_forward(inp, fp8_output, is_grad_enabled)
debug = self.is_debug_iter()
......@@ -1546,9 +1514,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
).is_fp8_ubuf():
fp8_grad = True
with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
with self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
......@@ -1556,14 +1522,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad)
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
(
input_quantizer,
......@@ -1574,18 +1540,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer,
) = quantizers
if torch.is_grad_enabled():
if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
args = []
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
......@@ -1607,8 +1568,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
......@@ -1625,7 +1586,15 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.symmetric_ar_type,
debug,
)
out = fwd_fn(*args)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
if self.return_layernorm_output:
out, ln_out = out
......@@ -1641,7 +1610,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return out, ln_out
return out
def _get_quantizers(self, fp8_output, fp8_grad):
def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
if not self.fp8:
return [None] * 6
grad_input_quantizer = None
......@@ -1653,7 +1622,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
(weight_quantizer,) = self._get_weight_quantizers()
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled():
if is_grad_enabled:
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True
if fp8_grad:
......@@ -1668,8 +1637,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
......@@ -1694,6 +1663,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self,
inp: torch.Tensor,
fp8_output: bool,
is_grad_enabled: bool,
) -> torch.Tensor:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
......@@ -1709,7 +1679,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(fp8_output, fp8_grad=False)
) = self._get_quantizers(fp8_output, False, is_grad_enabled)
inp_dtype = inp.dtype
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
......
......@@ -20,7 +20,6 @@ from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_custom
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
_ub_communicators,
get_ub,
TransformerEngineBaseModule,
......@@ -45,6 +44,7 @@ from ..utils import (
clear_tensor_data,
requires_grad,
needs_quantized_gemm,
get_nvtx_range_context,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -174,55 +174,61 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias: torch.Tensor,
fc2_weight: torch.Tensor,
fc2_bias: torch.Tensor,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer],
fc1_output_quantizer: Optional[Quantizer],
fc1_grad_input_quantizer: Optional[Quantizer],
fc1_grad_weight_quantizer: Optional[Quantizer],
fc1_grad_output_quantizer: Optional[Quantizer],
fc2_input_quantizer: Optional[Quantizer],
fc2_weight_quantizer: Optional[Quantizer],
fc2_output_quantizer: Optional[Quantizer],
fc2_grad_input_quantizer: Optional[Quantizer],
fc2_grad_weight_quantizer: Optional[Quantizer],
fc2_grad_output_quantizer: Optional[Quantizer],
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
return_layernorm_output: bool,
return_layernorm_output_gathered: bool,
bias_gelu_fusion: bool,
set_parallel_mode: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
activation_params: Optional[dict],
normalization: str,
ub_overlap_ag: bool,
ub_overlap_rs: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
gemm_gelu_fusion: bool,
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
debug: Optional[bool] = False,
non_tensor_args: Tuple,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
# Reduce number of arguments to autograd function in order
# to reduce CPU overhead due to pytorch arg checking.
(
eps,
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
return_layernorm_output,
return_layernorm_output_gathered,
bias_gelu_fusion,
set_parallel_mode,
is_grad_enabled,
fwd_ln_sm_margin,
bwd_ln_sm_margin,
zero_centered_gamma,
activation,
activation_params,
normalization,
ub_overlap_ag,
ub_overlap_rs,
ub_overlap_rs_dgrad,
ub_bulk_wgrad,
ub_bulk_dgrad,
gemm_gelu_fusion,
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
debug,
) = non_tensor_args
# Make sure input dimensions are compatible
in_features, inp_shape = ln_weight.numel(), inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
......@@ -433,7 +439,6 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_outputs = general_gemm(
fc1_weight_final,
ln_out_total,
get_workspace(),
quantization_params=(
fc2_input_quantizer
if gemm_gelu_fusion
......@@ -517,7 +522,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm(
fc2_weight_final,
act_out,
get_workspace(),
out_dtype=activation_dtype,
bias=fc2_bias,
quantization_params=fc2_output_quantizer,
......@@ -704,7 +708,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
with get_nvtx_range_context("_LayerNormMLP_backward"):
saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
......@@ -874,7 +878,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_output, *_ = general_gemm(
fc2_weight,
grad_output,
get_workspace(),
layout="NN",
grad=True,
quantization_params=(
......@@ -968,7 +971,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
......@@ -1138,7 +1140,6 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm(
fc1_weight,
dact,
get_workspace(),
out=gemm_out,
out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer,
......@@ -1217,7 +1218,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
fc1_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
......@@ -1399,52 +1399,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias_grad if fc1_bias is not None else None,
fc2_wgrad, # pylint: disable=possibly-used-before-assignment
fc2_bias_grad,
None, # eps
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer,
None, # fc1_weight_quantizer,
None, # fc1_output_quantizer,
None, # fc1_grad_input_quantizer,
None, # fc1_grad_weight_quantizer,
None, # fc1_grad_output_quantizer,
None, # fc2_input_quantizer,
None, # fc2_weight_quantizer,
None, # fc2_output_quantizer,
None, # fc2_grad_input_quantizer,
None, # fc2_grad_weight_quantizer,
None, # fc2_grad_output_quantizer,
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # return_layernorm_output
None, # return_layernorm_output_gathered
None, # bias_gelu_fusion
None, # set_parallel_mode
None, # is_grad_enabled
None, # fwd_ln_sm_margin
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # activation
None, # activation_params
None, # normalization
None, # ub_overlap_ag
None, # ub_overlap_rs
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # gemm_gelu_fusion
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # debug
None,
)
......@@ -1827,8 +1782,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode():
return self.onnx_forward(inp)
return self.onnx_forward(inp, is_grad_enabled)
debug = self.is_debug_iter()
......@@ -1844,19 +1801,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True
with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=2) as inp:
with self.prepare_forward(inp, num_gemms=2) as inp:
quantizers = (
self._get_quantizers(fp8_output)
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output)
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output)
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
# Get quantizers
(
......@@ -1888,20 +1843,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
if torch.is_grad_enabled():
if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
args = []
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
......@@ -1930,8 +1879,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
......@@ -1949,7 +1898,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.symmetric_ar_type,
debug,
)
out = fwd_fn(*args)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
if self.return_layernorm_output:
out, ln_out = out
......@@ -1965,7 +1924,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return out, ln_out
return out
def _get_quantizers(self, fp8_output):
def _get_quantizers(self, fp8_output, is_grad_enabled):
(
fc1_input_quantizer,
fc1_output_quantizer,
......@@ -1995,7 +1954,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_output_quantizer = self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM2_OUTPUT
]
if torch.is_grad_enabled():
if is_grad_enabled:
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
]
......@@ -2020,7 +1979,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_grad_output_quantizer,
)
def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
def onnx_forward(
self, inp: torch.Tensor, is_grad_enabled: bool
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
while only using operations that have defined ONNX symbolic translations.
......@@ -2037,7 +1998,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(False)
) = self._get_quantizers(False, is_grad_enabled)
inp_dtype = inp.dtype
fc1_weight, fc2_weight = self._get_weight_tensors()
......@@ -2122,10 +2083,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
return fc2_out, fc2_bias.to(inp_dtype)
return fc2_out
def _get_debug_quantizers(self, fp8_output):
def _get_debug_quantizers(self, fp8_output, is_grad_enabled):
from ...debug.pytorch.debug_quantization import DebugQuantizer
base_quantizers = list(self._get_quantizers(fp8_output))
base_quantizers = list(self._get_quantizers(fp8_output, is_grad_enabled))
assert TEDebugState.debug_enabled
def make_debug(prefix, offset):
......@@ -2268,7 +2229,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
"""
if not self.need_backward_dw():
return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
with get_nvtx_range_context("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
if self.use_bias and self.fc1_bias.grad is None:
(fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
......
......@@ -19,7 +19,6 @@ from .base import (
fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub,
get_workspace,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
......@@ -38,6 +37,7 @@ from ..utils import (
assert_dim_for_all_gather,
nvtx_range_pop,
nvtx_range_push,
get_nvtx_range_context,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
......@@ -90,42 +90,46 @@ class _Linear(torch.autograd.Function):
weight: torch.Tensor,
inp: torch.Tensor,
bias: Optional[torch.Tensor],
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
grad_input_quantizer: Optional[Quantizer],
grad_weight_quantizer: Optional[Quantizer],
grad_output_quantizer: Optional[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
ub_overlap_rs_fprop: bool,
ub_overlap_ag_dgrad: bool,
ub_overlap_ag_fprop: bool,
ub_overlap_rs_dgrad: bool,
ub_bulk_dgrad: bool,
ub_bulk_wgrad: bool,
ub_name: str,
fp8_output: bool, # pylint: disable=unused-argument
fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module,
skip_fp8_weight_update: bool,
symmetric_ar_type: str,
save_original_input: bool = False,
debug: Optional[bool] = False,
non_tensor_args: Tuple,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
(
is_first_microbatch,
fp8,
fp8_calibration,
wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
fuse_wgrad_accumulation,
cpu_offloading,
tp_group,
tp_size,
sequence_parallel,
tensor_parallel,
activation_dtype,
parallel_mode,
is_grad_enabled,
ub_overlap_rs_fprop,
ub_overlap_ag_dgrad,
ub_overlap_ag_fprop,
ub_overlap_rs_dgrad,
ub_bulk_dgrad,
ub_bulk_wgrad,
ub_name,
fp8_output, # pylint: disable=unused-variable
fsdp_group,
module,
skip_fp8_weight_update,
symmetric_ar_type,
save_original_input,
debug,
) = non_tensor_args
# NVTX label for profiling
nvtx_label = "transformer_engine._Linear.forward"
if ub_name is not None:
......@@ -320,7 +324,6 @@ class _Linear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat,
inputmat_total,
get_workspace(),
quantization_params=output_quantizer,
out_dtype=activation_dtype,
bias=bias,
......@@ -497,7 +500,7 @@ class _Linear(torch.autograd.Function):
if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_Linear_backward"):
with get_nvtx_range_context("_Linear_backward"):
saved_tensors = ctx.saved_tensors
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors)
......@@ -719,7 +722,6 @@ class _Linear(torch.autograd.Function):
gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8,
grad_output,
get_workspace(),
layout="NN",
grad=True,
quantization_params=ctx.grad_input_quantizer,
......@@ -845,7 +847,6 @@ class _Linear(torch.autograd.Function):
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
......@@ -977,39 +978,7 @@ class _Linear(torch.autograd.Function):
wgrad,
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias,
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
None, # grad_input_quantizer
None, # grad_weight_quantizer
None, # grad_output_quantizer
None, # fuse_wgrad_accumulation
None, # cpu_offloading
None, # tp_group
None, # tp_size
None, # sequence_parallel
None, # tensor_parallel
None, # activation_dtype
None, # parallel_mode
None, # is_grad_enabled
None, # ub_overlap_rs_fprop
None, # ub_overlap_ag_dgrad
None, # ub_overlap_ag_fprop
None, # ub_overlap_rs_dgrad
None, # ub_bulk_dgrad
None, # ub_bulk_wgrad
None, # ub_name
None, # fp8_output
None, # fsdp_group
None, # module
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # save_original_input
None, # debug
None,
)
......@@ -1403,8 +1372,10 @@ class Linear(TransformerEngineBaseModule):
first microbatch (since it is the first gradient being
produced)
"""
is_grad_enabled = torch.is_grad_enabled()
if is_in_onnx_export_mode():
return self.onnx_forward(inp, fp8_output)
return self.onnx_forward(inp, fp8_output, is_grad_enabled)
debug = self.is_debug_iter()
......@@ -1426,9 +1397,7 @@ class Linear(TransformerEngineBaseModule):
).is_fp8_ubuf():
fp8_grad = True
with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
with self.prepare_forward(
inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
......@@ -1436,14 +1405,14 @@ class Linear(TransformerEngineBaseModule):
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad)
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad)
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad)
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
(
input_quantizer,
......@@ -1454,16 +1423,14 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer,
) = quantizers
if torch.is_grad_enabled():
if is_grad_enabled:
linear_fn = _Linear.apply
args = []
autograd_ctx = []
else:
linear_fn = _Linear.forward
args = [None]
args += (
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
autograd_ctx = [None]
non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
......@@ -1482,7 +1449,7 @@ class Linear(TransformerEngineBaseModule):
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
......@@ -1498,7 +1465,13 @@ class Linear(TransformerEngineBaseModule):
self.save_original_input,
debug,
)
out = linear_fn(*args)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
......@@ -1506,7 +1479,7 @@ class Linear(TransformerEngineBaseModule):
return out, cast_if_needed(bias_tensor, self.activation_dtype)
return out
def _get_quantizers(self, fp8_output, fp8_grad):
def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
if not self.fp8:
return [None] * 6
grad_input_quantizer = None
......@@ -1518,7 +1491,7 @@ class Linear(TransformerEngineBaseModule):
(weight_quantizer,) = self._get_weight_quantizers()
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
if torch.is_grad_enabled():
if is_grad_enabled:
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True
if fp8_grad:
......@@ -1532,8 +1505,8 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer,
)
def _get_debug_quantizers(self, fp8_output, fp8_grad):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad)
def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled):
original_quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
assert TEDebugState.debug_enabled
from ...debug.pytorch.debug_quantization import DebugQuantizer
......@@ -1588,6 +1561,7 @@ class Linear(TransformerEngineBaseModule):
self,
inp: torch.Tensor,
fp8_output: bool,
is_grad_enabled: bool,
) -> torch.Tensor:
"""
ONNX-compatible version of the forward function that provides numerical equivalence
......@@ -1604,7 +1578,7 @@ class Linear(TransformerEngineBaseModule):
weight_quantizer,
output_quantizer,
*_,
) = self._get_quantizers(fp8_output, False)
) = self._get_quantizers(fp8_output, False, is_grad_enabled)
inp_dtype = inp.dtype
if input_quantizer is not None:
......
......@@ -25,7 +25,6 @@ from ...module.base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
get_dummy_wgrad,
get_workspace,
)
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
......@@ -585,7 +584,6 @@ class BasicLinear(BasicOperation):
y, *_ = general_gemm(
w,
x,
get_workspace(),
out_dtype=dtype,
quantization_params=output_quantizer,
alpha=alpha,
......@@ -875,7 +873,6 @@ class BasicLinear(BasicOperation):
dx, *_ = general_gemm(
w,
dy,
get_workspace(),
out_dtype=dtype,
quantization_params=grad_input_quantizer,
alpha=grad_input_alpha,
......@@ -928,7 +925,6 @@ class BasicLinear(BasicOperation):
dw, *_ = general_gemm(
x,
dy,
get_workspace(),
out_dtype=dw_dtype,
alpha=grad_weight_alpha,
beta=grad_weight_beta,
......
......@@ -19,7 +19,6 @@ from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub,
get_workspace,
)
from ...quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
......@@ -378,7 +377,6 @@ class UserbuffersBackwardLinear(FusedOperation):
dx, *_ = general_gemm(
w,
dy,
get_workspace(),
out_dtype=dtype,
quantization_params=grad_input_quantizer,
layout="NN",
......@@ -464,7 +462,6 @@ class UserbuffersBackwardLinear(FusedOperation):
dw, *_ = general_gemm(
x,
dy,
get_workspace(),
out_dtype=dw_dtype,
accumulate=accumulate_into_grad_weight,
layout="NT",
......