Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
......@@ -81,15 +81,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph);
int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::vector<int64_t> window_size, bool bottom_right_diagonal,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
......@@ -99,10 +100,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......@@ -198,6 +199,11 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st
* Activations
**************************************************************************************************/
/* GLU (sigmoid gate) */
py::object glu(const at::Tensor &input, py::handle quantizer);
py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
/* GELU and variants */
py::object gelu(const at::Tensor &input, py::handle quantizer);
......@@ -585,6 +591,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
~CommOverlap() {}
using transformer_engine::CommOverlapCore::copy_into_buffer;
void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
at::Tensor get_buffer(bool local_chunk = false,
......@@ -606,6 +613,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
~CommOverlapP2P() {}
using transformer_engine::CommOverlapP2PBase::copy_into_buffer;
void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
at::Tensor get_buffer(bool local_chunk = false,
......
......@@ -246,6 +246,14 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua
return dactivation_helper<nvte_dgelu, nullptr>(grad, input, quantizer);
}
py::object glu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_glu, nullptr>(input, quantizer, 2);
}
py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dglu, nullptr>(grad, input, quantizer);
}
py::object geglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_geglu, nullptr>(input, quantizer, 2);
}
......
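For reference, the glu/dglu helpers added above bind nvte_glu/nvte_dglu, which gate the second half of the last dimension with the sigmoid of the first half. An editor's sketch checking those semantics against PyTorch autograd (illustrative, not the TE kernel itself):

import torch

# GLU forward sigma(a) * b and its analytic backward, verified with autograd.
# `a` is the first half of the last dimension, `b` the second.
x = torch.randn(2, 6, requires_grad=True)
a, b = x.chunk(2, dim=-1)
y = torch.sigmoid(a) * b
y.backward(torch.ones_like(y))
s = torch.sigmoid(a).detach()
grad_ref = torch.cat([s * (1 - s) * b.detach(), s], dim=-1)
assert torch.allclose(x.grad, grad_ref)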
......@@ -45,7 +45,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
#ifdef __HIP_PLATFORM_AMD__
return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
#else
......@@ -53,7 +53,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
return_max_logit, cuda_graph);
return_max_logit, cuda_graph, deterministic);
return fused_attention_backend;
#endif
}
......@@ -104,9 +104,10 @@ std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::vector<int64_t> window_size, bool bottom_right_diagonal,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
......@@ -242,7 +243,7 @@ std::vector<py::object> fused_attn_fwd(
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
at::cuda::getCurrentCUDAStream());
});
......@@ -302,7 +303,7 @@ std::vector<py::object> fused_attn_fwd(
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
at::cuda::getCurrentCUDAStream());
});
......@@ -318,10 +319,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......@@ -543,14 +544,14 @@ std::vector<py::object> fused_attn_bwd(
// populate tensors with appropriate shapes and dtypes
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// allocate memory for workspace
......@@ -560,14 +561,14 @@ std::vector<py::object> fused_attn_bwd(
// execute kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// destroy tensor wrappers
......
......@@ -132,6 +132,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
/* GLU (sigmoid gate) */
m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"),
py::arg("quantizer"));
/* GELU and variants */
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer"));
......@@ -158,6 +161,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
"SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
/* Backward of GLU */
m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
/* Backward of GELU and variants */
m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
......@@ -515,8 +521,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
.def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
py::arg("local_chunk") = false)
.def("copy_into_buffer",
static_cast<void (CommOverlap::*)(const at::Tensor &, bool)>(
&CommOverlap::copy_into_buffer),
py::arg("input"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt)
.def("get_communication_stream", &CommOverlap::get_communication_stream);
......@@ -533,8 +541,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
py::arg("use_ce") = true, py::arg("aggregate") = false)
.def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
py::arg("local_chunk") = false)
.def("copy_into_buffer",
static_cast<void (CommOverlapP2P::*)(const at::Tensor &, bool)>(
&CommOverlapP2P::copy_into_buffer),
py::arg("input"), py::arg("local_chunk") = false)
.def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
py::arg("shape") = std::nullopt)
.def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
......
......@@ -729,8 +729,8 @@ def checkpoint(
if isinstance(function, TransformerEngineBaseModule):
# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
setattr(function, "fsdp_wrapped", False)
setattr(function, "fsdp_group", None)
function.fast_setattr("fsdp_wrapped", False)
function.fast_setattr("fsdp_group", None)
# Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
# and execute TE's own checkpointing
......@@ -2022,7 +2022,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
)
root_state = _get_module_fsdp_state(fsdp_root)
assert root_state is not None, "Root module does not have a valid _FSDPState."
setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group)
# Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
......@@ -2033,7 +2033,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
"TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
"Please initialize your model without the te.quantized_model_init(...) context."
)
setattr(fsdp_module.module, "fsdp_group", state.process_group)
fsdp_module.module.fast_setattr("fsdp_group", state.process_group)
class FullyShardedDataParallel(FSDP):
......
......@@ -451,11 +451,12 @@ def _make_graphed_callables(
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs):
outputs_requiring_grad = tuple(
o for o in outputs if o is not None and o.requires_grad
)
torch.autograd.backward(
tuple(o for o in outputs if o.requires_grad),
grad_tensors=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
outputs_requiring_grad,
grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
)
grad_inputs = tuple(input.grad for input in inputs)
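The change above filters the static outputs once and derives both the backward roots and their grad tensors from the same tuple, so None placeholders cannot desynchronize the two sequences. A self-contained sketch of the pattern with illustrative tensors:

import torch

# Filter once, then build backward roots and grad_tensors from the same
# tuple so they stay aligned even when outputs contain None entries.
outputs = (torch.randn(3, requires_grad=True) * 2.0, None)
outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
torch.autograd.backward(
    outputs_requiring_grad,
    grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
)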
......@@ -616,19 +617,22 @@ def _make_graphed_callables(
# Note for _reuse_graph_input_output_buffers: grad output is only used
# within backward, so we can reuse the same static buffers every time.
static_grad_outputs_keys = tuple(
(o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
(o.shape, o.dtype, o.layout)
for o in static_outputs
if o is not None and o.requires_grad
)
if static_grad_outputs_keys in static_grad_outputs_dict:
static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
else:
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
......@@ -636,7 +640,9 @@ def _make_graphed_callables(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
tuple(
o for o in static_outputs if o is not None and o.requires_grad
),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
......@@ -719,7 +725,8 @@ def _make_graphed_callables(
):
# For now, assumes all static_outputs require grad
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
torch.empty_like(o) if o is not None and o.requires_grad else None
for o in static_outputs
)
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
......@@ -727,7 +734,7 @@ def _make_graphed_callables(
bwd_graph, pool=mempool
):
torch.autograd.backward(
tuple(o for o in static_outputs if o.requires_grad),
tuple(o for o in static_outputs if o is not None and o.requires_grad),
grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=retain_graph_in_backward,
)
......@@ -794,7 +801,7 @@ def _make_graphed_callables(
# Replay forward graph
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
return tuple(o.detach() if o is not None else o for o in static_outputs)
@staticmethod
@torch.autograd.function.once_differentiable
......@@ -853,12 +860,22 @@ def _make_graphed_callables(
return functionalized
def make_graphed_attribute_functions(graph_idx):
# Get te modules for current graph
te_modules = visited_te_modules.get(graph_idx, set())
# Attach backward_dw as an attribute to the graphed callable.
def backward_dw():
if need_bwd_dw_graph.get(graph_idx, False):
bwd_dw_graphs[graph_idx].replay()
# Trigger the grad accumulation hook for wgrad graphs.
for module in te_modules:
if (
isinstance(module, TransformerEngineBaseModule)
and module.need_backward_dw()
):
module._trigger_wgrad_accumulation_and_reduce_hooks()
# Attach reset as an attribute to the graphed callable.
def reset():
fwd_graphs[graph_idx].reset()
......
......@@ -47,17 +47,35 @@ if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"
# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo
if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: (
f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
def no_torch_dynamo(recursive=True):
"""Decorator to disable Torch Dynamo, except during ONNX export."""
def decorator(f):
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
disabled_f = (
torch._dynamo.disable(f, recursive=recursive)
if torch.__version__ >= "2.1"
else torch._dynamo.disable(f)
)
@wraps(f)
def wrapper(*args, **kwargs):
if is_in_onnx_export_mode():
return f(*args, **kwargs)
return disabled_f(*args, **kwargs)
return wrapper
return decorator
else:
# Fallback for PyTorch < 2.0: no-op decorator
def no_torch_dynamo(recursive=True): # pylint: disable=unused-argument
"""No-op decorator for PyTorch < 2.0."""
return lambda func: func
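The rewritten decorator applies torch._dynamo.disable eagerly but defers the bypass decision to call time, so ONNX export sees the original function. A generic, runnable sketch of that dispatch pattern (illustrative helper, not TE's exact code):

from functools import wraps

def conditional_disable(disable, should_bypass):
    # Apply `disable` once at decoration time, but decide per call whether
    # to run the original function instead (e.g. during ONNX export).
    def decorator(f):
        disabled_f = disable(f)
        @wraps(f)
        def wrapper(*args, **kwargs):
            if should_bypass():
                return f(*args, **kwargs)
            return disabled_f(*args, **kwargs)
        return wrapper
    return decorator

@conditional_disable(disable=lambda f: f, should_bypass=lambda: False)
def double(x):
    return 2 * x

assert double(3) == 6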
def set_jit_fusion_options() -> None:
......
......@@ -90,6 +90,8 @@ class _NoopCatFunc(torch.autograd.Function):
# Check first tensor
if not tensors:
raise ValueError("Attempted to concatenate 0 tensors")
# Check concat dim
num_dims = tensors[0].dim()
if not -num_dims <= dim < num_dims:
raise ValueError(
......@@ -122,11 +124,24 @@ class _NoopCatFunc(torch.autograd.Function):
ctx.dim = dim
ctx.split_ranges = split_ranges
# Out-of-place concatenation if needed
# Tensor properties from first tensor
dtype = tensors[0].dtype
device = tensors[0].device
strides = tensors[0].stride()
data_ptr_stride = strides[dim] * tensors[0].element_size()
# Out-of-place concatenation when view tensors have different storage
# Note: This works around an edge case with the split_quantize
# function, which might allocate a buffer and construct
# subviews. However, in order to reduce CPU overheads, these
# views are configured manually outside of PyTorch. PyTorch
# doesn't know these views share the same memory, and it
# blocks us from reconstructing the full tensor because it
# thinks we are accessing out-of-bounds memory.
if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride:
return torch.cat(tensors, dim=dim)
# Out-of-place concatenation if tensor properties do not match
data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride
for tensor in tensors[1:]:
if (
......@@ -139,13 +154,7 @@ class _NoopCatFunc(torch.autograd.Function):
data_ptr += tensor.size(dim) * data_ptr_stride
# No-op concatenation
out = tensors[0].new()
out.set_(
tensors[0].untyped_storage(),
tensors[0].storage_offset(),
out_shape,
strides,
)
out = tensors[0].as_strided(out_shape, strides)
out.requires_grad = any(tensor.requires_grad for tensor in tensors)
return out
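The as_strided rewrite preserves the no-op concatenation semantics: when the inputs are back-to-back views of one storage, the "concatenated" tensor is just a larger strided view over that storage. A runnable sketch:

import torch

# Two contiguous views laid out back-to-back in one buffer can be
# "concatenated" with as_strided, with no copy and no new allocation.
buf = torch.arange(12.0)
a = buf[:6].view(2, 3)
b = buf[6:].view(2, 3)
out = a.as_strided((4, 3), a.stride())
assert torch.equal(out, torch.cat([a, b], dim=0))
assert out.data_ptr() == a.data_ptr()  # same storage, no copy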
......
......@@ -10,9 +10,8 @@ import pickle
import warnings
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from contextlib import contextmanager
import logging
from types import MethodType
import torch
......@@ -50,6 +49,8 @@ from ..utils import (
is_non_tn_fp8_gemm_supported,
torch_get_autocast_gpu_dtype,
get_nvtx_range_context,
nvtx_range_push,
nvtx_range_pop,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe
......@@ -644,10 +645,10 @@ def fill_userbuffers_buffer_for_all_gather(
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
def __init__(self) -> None:
def __init__(self, name: Optional[str] = None) -> None:
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.name = None
self.name = name
self.next_iter_when_debug_should_be_run = 0
self.fp8_initialized = False
self.fp8 = False
......@@ -672,26 +673,22 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
self._validate_name()
# Names of attributes that can be set quickly (see __setattr__
# method)
_fast_setattr_names: Set[str] = {
"activation_dtype",
"fp8",
"fp8_initialized",
"fp8_calibration",
"fp8_parameters",
}
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
def fast_setattr(self, name: str, value: Any) -> None:
"""
Fast version of the Module's set attribute function.
Should be used for regular attributes, but not properties nor parameters/buffers.
"""
self.__dict__[name] = value
def module_setattr(self, name: str, value: Any) -> None:
"""
Regular version of the Module's set attribute function.
Should be used only when the fast version cannot be used: for properties,
parameters, and buffers.
"""
super().__setattr__(name, value)
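fast_setattr exists because torch.nn.Module.__setattr__ checks the parameter, buffer, and submodule registries on every assignment; writing directly into __dict__ skips that overhead for plain attributes. A minimal self-contained sketch:

import torch

class M(torch.nn.Module):
    def fast_setattr(self, name, value):
        # Bypass nn.Module.__setattr__ bookkeeping for plain attributes.
        self.__dict__[name] = value

m = M()
m.fast_setattr("activation_dtype", torch.float32)  # no registry scan
m.some_param = torch.nn.Parameter(torch.zeros(1))  # still routed through nn.Module
assert m.activation_dtype is torch.float32
assert "some_param" in dict(m.named_parameters())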
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""
......@@ -812,7 +809,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.set_meta_tensor(True, recipe)
self.set_meta_tensor(False, recipe)
self.fp8_meta_tensors_initialized = True
self.fast_setattr("fp8_meta_tensors_initialized", True)
def get_fp8_meta_tensors(self) -> None:
"""Get scales and amaxes."""
......@@ -969,7 +966,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Get activation data type for AMP."""
# Native AMP (`torch.autocast`) gets highest priority
if torch.is_autocast_enabled():
self.activation_dtype = torch_get_autocast_gpu_dtype()
self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype())
return
# All checks after this have already been performed once, thus skip
......@@ -984,7 +981,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype
self.fast_setattr("activation_dtype", dtype)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""
......@@ -996,8 +993,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
tp_group : ProcessGroup, default = None
tensor parallel process group.
"""
self.tp_group = tp_group
self.tp_group_initialized = True
self.fast_setattr("tp_group", tp_group)
self.fast_setattr("tp_group_initialized", True)
def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
"""returns the FP8 weights."""
......@@ -1013,48 +1010,51 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# assume FP8 execution.
def init_fp8_metadata(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
_original_recipe = self.fp8_meta.get("recipe", None)
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters or fp8_enabled:
if (
self.fp8_initialized
and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
):
meta = self.fp8_meta
fp8 = FP8GlobalStateManager.is_fp8_enabled()
fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self.fast_setattr("fp8_parameters", fp8_parameters)
self.fast_setattr("fp8", fp8)
self.fast_setattr("fp8_calibration", fp8_calibration)
fp8_enabled = fp8 or fp8_calibration
meta["fp8_checkpoint"] = fp8_enabled
_original_recipe = None
if fp8_parameters or fp8_enabled:
_original_recipe = meta.get("recipe", None)
if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
# FP8 init has already been run and recipe is the same, don't do anything.
return
self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
self.fast_setattr("fp8_initialized", False)
return
if self.fp8_parameters and not self.fp8_initialized:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
if fp8_parameters and not self.fp8_initialized:
meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(meta["recipe"])
if fp8_enabled:
# Set FP8 and other FP8 metadata
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
meta["num_gemms"] = num_gemms
meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Set FP8_MAX per tensor according to recipe
if hasattr(self.fp8_meta["recipe"], "fp8_format"):
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
if hasattr(meta["recipe"], "fp8_format"):
meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd
meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
self.fp8_initialized = True
self.init_fp8_meta_tensors(meta["recipe"])
self.fast_setattr("fp8_initialized", True)
self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
_current_recipe = self.fp8_meta["recipe"]
_current_recipe = meta["recipe"]
if _original_recipe is not None and not (
issubclass(_current_recipe.__class__, _original_recipe.__class__)
or issubclass(_original_recipe.__class__, _current_recipe.__class__)
......@@ -1067,22 +1067,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Clear cached workspaces as they were created with the old recipe/quantizer type
self._fp8_workspaces.clear()
@contextmanager
def prepare_forward(
self,
inp: torch.Tensor,
num_gemms: int = 1,
allow_non_contiguous: bool = False,
allow_different_data_and_param_types: bool = False,
) -> Generator[torch.Tensor, None, None]:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
self.allow_different_data_and_param_types = allow_different_data_and_param_types
self.forwarded_at_least_once = True
) -> torch.Tensor:
"""Checks and prepares for FWD execution."""
self.fast_setattr(
"allow_different_data_and_param_types", allow_different_data_and_param_types
)
self.fast_setattr("forwarded_at_least_once", True)
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
......@@ -1113,13 +1109,37 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
yield inp
nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
return inp
def end_forward(self):
"""
Must be called at the end of the forward function to finalize DelayedScaling
metadata handling and close the NVTX range opened by prepare_forward().
"""
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
nvtx_range_pop()
@contextmanager
def prepare_forward_ctx(
self,
inp: torch.Tensor,
num_gemms: int = 1,
allow_non_contiguous: bool = False,
allow_different_data_and_param_types: bool = False,
) -> Generator[torch.Tensor, None, None]:
"""Checks and prepares for FWD execution."""
inp = self.prepare_forward(
inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types
)
try:
yield inp
finally:
self.end_forward()
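The refactor replaces the generator-based context manager with paired prepare_forward()/end_forward() calls on the hot path, keeping prepare_forward_ctx() as a thin compatibility wrapper. A minimal sketch of the pattern with illustrative names:

from contextlib import contextmanager

class Layer:
    # Paired calls avoid creating a generator frame on every forward;
    # the context manager survives for callers that still want `with`.
    def prepare_forward(self, inp):
        # checks, NVTX push, quantization setup would happen here
        return inp

    def end_forward(self):
        # NVTX pop, DelayedScaling bookkeeping would happen here
        pass

    @contextmanager
    def prepare_forward_ctx(self, inp):
        inp = self.prepare_forward(inp)
        try:
            yield inp
        finally:
            self.end_forward()

layer = Layer()
x = layer.prepare_forward([1, 2, 3])
try:
    out = sum(x)
finally:
    layer.end_forward()
assert out == 6

The try/finally form is exactly what the Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear forward paths switch to in the hunks below.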
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
......@@ -1354,9 +1374,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Update the parameter based on its type
if not is_dtensor:
setattr(self, name, param)
self.module_setattr(name, param)
else:
setattr(self, name, dtensor_param)
self.module_setattr(name, dtensor_param)
@abstractmethod
def forward(self):
......@@ -1545,8 +1565,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
bias_tensor.grad = bgrad.to(bias_tensor.dtype)
del wgrad
del bgrad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()
def _trigger_wgrad_accumulation_and_reduce_hooks(self):
"""
Trigger the wgrad accumulation and reduce hooks.
"""
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
def is_debug_iter(self) -> bool:
"""
......@@ -1555,7 +1581,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
debug = TEDebugState.debug_enabled
if not debug:
return False
self._validate_name()
# If layer is run first time in new iteration,
# we need to check if the debug should be enabled for this layer -
......@@ -1569,14 +1594,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
debug = False
else:
debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
self.debug_last_iteration = TEDebugState.get_iteration()
self.debug_enabled_in_this_iteration = debug
self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
self.fast_setattr("debug_enabled_in_this_iteration", debug)
else:
# If this is the same iteration as previous invocation of the module,
# we use the debug value from the first invocation in the iteration.
debug = self.debug_enabled_in_this_iteration
self.debug_last_iteration = TEDebugState.get_iteration()
self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
if self.wgrad_store is not None:
if debug and self.wgrad_store.delay_wgrad_compute():
......@@ -1592,7 +1617,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Sometimes features inform that they will not be enabled for particular layer
# for multiple next iterations.
self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
self.fast_setattr(
"next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers)
)
if not run_current:
return True
......@@ -1604,22 +1631,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def _validate_name(self):
"""
Validate name passed to the module.
This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
If no name is assigned, it creates a default name with layer count as the variable.
It creates a default name with layer count as the variable
which may be changed by the user of the module.
"""
if self.name is not None:
return
assert TEDebugState.debug_enabled
import nvdlfw_inspect.api as debug_api
if self.name is None:
debug_api.log_message(
"Names are not provided to debug modules. ",
"Creating and using generic names. Pass names to debug modules for better"
" insight. ",
level=logging.WARNING,
)
self.name = f"Layer_{TEDebugState.get_layer_count()}"
self.name = f"Layer_{TEDebugState.get_layer_count()}"
def _check_weight_tensor_recipe_correspondence(self) -> None:
"""
......
......@@ -15,6 +15,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from .base import (
get_dummy_wgrad,
TransformerEngineBaseModule,
......@@ -149,7 +150,10 @@ class _GroupedLinear(torch.autograd.Function):
# tensors (like scales), but bulk allocation shares storage across all tensors,
# so if scales can't be offloaded, nothing in the group can be offloaded.
inputmats = tex.split_quantize(
inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
inp_view,
m_splits,
input_quantizers,
disable_bulk_allocation=cpu_offloading,
)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
......@@ -367,7 +371,10 @@ class _GroupedLinear(torch.autograd.Function):
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = DebugQuantizer.multi_tensor_quantize(
grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
grad_output_view,
ctx.grad_output_quantizers,
ctx.m_splits,
ctx.activation_dtype,
)
else:
# Only split grad output. Grad bias is fused with
......@@ -438,7 +445,8 @@ class _GroupedLinear(torch.autograd.Function):
if ctx.input_quantizers[0] is not None:
for input_quantizer in ctx.input_quantizers:
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
input_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
input_quantizer.set_usage(rowwise=True, columnwise=True)
else:
......@@ -448,7 +456,10 @@ class _GroupedLinear(torch.autograd.Function):
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
elif ctx.debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
inp_view,
ctx.input_quantizers,
ctx.m_splits,
ctx.activation_dtype,
)
else:
inputmats = torch.split(
......@@ -623,9 +634,9 @@ class GroupedLinear(TransformerEngineBaseModule):
save_original_input: bool = False,
name: Optional[str] = None,
) -> None:
super().__init__()
super().__init__(name)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.num_gemms = num_gemms
self.in_features = in_features
self.out_features = out_features
......@@ -640,13 +651,19 @@ class GroupedLinear(TransformerEngineBaseModule):
assert (
not ub_overlap_rs and not ub_overlap_ag
), "GroupedLinear doesn't support Userbuffer overlap."
self.init_method = init_method
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.name = name
self.wgrad_store = WeightGradStore(delay_wgrad_compute)
self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
self._offsets = {
"input": 0,
"weight": 1,
"output": 2,
"grad_output": 0,
"grad_input": 1,
}
self._num_fp8_tensors_per_gemm = {
"fwd": 3,
"bwd": 2,
......@@ -688,7 +705,7 @@ class GroupedLinear(TransformerEngineBaseModule):
self.out_features,
self.in_features,
device=device,
dtype=params_dtype,
dtype=self.params_dtype,
),
),
init_fn=init_method,
......@@ -704,13 +721,13 @@ class GroupedLinear(TransformerEngineBaseModule):
torch.empty(
self.out_features,
device=device,
dtype=params_dtype,
dtype=self.params_dtype,
),
),
init_fn=init_method_constant(0.0),
)
else:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
bias = torch.Tensor().to(dtype=self.params_dtype, device=device)
setattr(self, f"bias{i}", bias)
if self.primary_weights_in_fp8:
......@@ -734,8 +751,63 @@ class GroupedLinear(TransformerEngineBaseModule):
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
def make_grouped_weights(self, defer_init=False) -> None:
"""
Convert parameters into a GroupedTensor and re-register them as parameters.
"""
if defer_init:
return
weight_quantizers = self._get_weight_quantizers()
recipe = (
weight_quantizers[0]._get_compatible_recipe()
if weight_quantizers and weight_quantizers[0] is not None
else None
)
if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()):
self.set_tensor_parallel_attributes(defer_init=defer_init)
return
weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
# Create the weight storage.
grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=self.num_gemms,
shape=[(self.out_features, self.in_features)] * self.num_gemms,
quantizer=weight_quantizers[0],
dtype=self.params_dtype,
device=weights[0].device,
)
# Copy existing params into storage.
with torch.no_grad():
for i in range(self.num_gemms):
if self.primary_weights_in_fp8:
grouped_weights.quantized_tensors[i].copy_from_storage(weights[i])
else:
grouped_weights.quantized_tensors[i].copy_(weights[i])
# Re-register the grouped weights as parameters.
for i in range(self.num_gemms):
self.register_parameter(
f"weight{i}",
torch.nn.Parameter(grouped_weights.quantized_tensors[i]),
init_fn=self.init_method,
get_rng_state_tracker=self.get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
)
self.set_tensor_parallel_attributes(defer_init=defer_init)
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
# Grouped tensor weights are an opt-in feature.
if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
self.make_grouped_weights(defer_init=defer_init)
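Since grouped weight storage is opt-in, enabling it is a one-line environment setting before module construction (variable name taken from the check above):

import os

# Opt in to contiguous grouped weight storage (off by default).
os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"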
def set_tensor_parallel_attributes(self, defer_init=False) -> None:
"""Set attributes needed for TP"""
if not defer_init:
# Set parallelism attributes for linear weights
......@@ -798,7 +870,8 @@ class GroupedLinear(TransformerEngineBaseModule):
is_grad_enabled = torch.is_grad_enabled()
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
try:
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
......@@ -853,6 +926,9 @@ class GroupedLinear(TransformerEngineBaseModule):
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
finally:
self.end_forward()
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
......@@ -879,8 +955,7 @@ class GroupedLinear(TransformerEngineBaseModule):
del grad_biases_
del wgrad_list
del tensor_list
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
......@@ -932,7 +1007,7 @@ class GroupedLinear(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8 and not self.fp8_calibration:
if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8:
return [None] * self.num_gemms
weight_quantizers = [
self.quantizers["scaling_fwd"][
......@@ -941,7 +1016,7 @@ class GroupedLinear(TransformerEngineBaseModule):
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
weight_quantizers[i].internal = not self.primary_weights_in_fp8
return weight_quantizers
def _get_quantizers(self):
......
......@@ -1177,9 +1177,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
name: str = None,
name: Optional[str] = None,
) -> None:
super().__init__()
super().__init__(name)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.in_features = in_features
......@@ -1198,7 +1198,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.symmetric_ar_type = symmetric_ar_type
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name
if tp_group is None:
self.tp_size = tp_size
......@@ -1527,10 +1526,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
).is_fp8_ubuf():
fp8_grad = True
with self.prepare_forward(
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)
try:
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
......@@ -1609,6 +1609,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
non_tensor_args,
)
finally:
self.end_forward()
if self.return_layernorm_output:
out, ln_out = out
......
......@@ -107,6 +107,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
return {
"gelu": (tex.gelu, tex.dgelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"glu": (tex.glu, tex.dglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, None),
......@@ -123,6 +124,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
return {
"gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
"geglu": (tex.geglu, tex.dgeglu, None),
"glu": (tex.glu, tex.dglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, tex.dbias_drelu),
......@@ -145,6 +147,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
return {
"gelu": (tex.gelu, tex.dgelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"glu": (tex.glu, tex.dglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, None),
......@@ -1695,7 +1698,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
type of normalization applied.
activation : str, default = 'gelu'
activation function used.
Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
Options: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
activation_params : dict, default = None
Additional parameters for the activation function.
......@@ -1817,7 +1820,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
ub_overlap_ag: bool = False,
name: str = None,
name: Optional[str] = None,
ub_overlap_rs: bool = False,
ub_overlap_rs_dgrad: bool = False,
ub_bulk_dgrad: bool = False,
......@@ -1826,7 +1829,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
symmetric_ar_type: Optional[str] = None,
checkpoint: bool = False,
) -> None:
super().__init__()
super().__init__(name)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
......@@ -1857,7 +1860,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
for use_fp8 in [False, True]
)
)
self.name = name
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
......@@ -1915,7 +1917,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_bias = None
# FC1 init
if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]:
if self.activation in [
"geglu",
"glu",
"qgeglu",
"reglu",
"sreglu",
"swiglu",
"clamped_swiglu",
]:
fc1_output_features = 2 * self.size_per_partition
else:
fc1_output_features = self.size_per_partition
......@@ -2077,8 +2087,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True
with self.prepare_forward(inp, num_gemms=2) as inp:
inp = self.prepare_forward(inp, num_gemms=2)
try:
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
......@@ -2118,7 +2129,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if ( not IS_HIP_EXTENSION
and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute() ):
self.bias_gelu_nvfusion = False
self.fast_setattr("bias_gelu_nvfusion", False)
if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
......@@ -2188,6 +2199,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
non_tensor_args,
)
finally:
self.end_forward()
if self.return_layernorm_output:
out, ln_out = out
......@@ -2336,6 +2350,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"glu": lambda x: torch.sigmoid(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
* x.chunk(2, -1)[1],
......@@ -2534,5 +2549,4 @@ class LayerNormMLP(TransformerEngineBaseModule):
del fc2_wgrad
del fc1_wgrad
del fc1_bias_grad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()
......@@ -429,8 +429,8 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
ctx.weight_object = weight
if cpu_offloading:
mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
......@@ -1103,7 +1103,7 @@ class Linear(TransformerEngineBaseModule):
save_original_input: bool = False,
name: Optional[str] = None,
) -> None:
super().__init__()
super().__init__(name)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.in_features = in_features
......@@ -1116,7 +1116,6 @@ class Linear(TransformerEngineBaseModule):
self.rng_tracker_name = rng_tracker_name
self.symmetric_ar_type = symmetric_ar_type
self.save_original_input = save_original_input
self.name = name
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
......@@ -1400,11 +1399,8 @@ class Linear(TransformerEngineBaseModule):
).is_fp8_ubuf():
fp8_grad = True
with self.prepare_forward(
inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
try:
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
......@@ -1475,6 +1471,8 @@ class Linear(TransformerEngineBaseModule):
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
finally:
self.end_forward()
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
......
......@@ -8,7 +8,9 @@ This operation-based API is experimental and subject to change.
"""
from transformer_engine.pytorch.ops.basic import *
from transformer_engine.pytorch.ops.linear import Linear
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.sequential import Sequential
from .basic import *
from .fuser import register_backward_fusion, register_forward_fusion
from .linear import Linear
from .op import BasicOperation, FusedOperation, FusibleOperation
from .sequential import Sequential
from . import fused
......@@ -7,6 +7,7 @@
from .activation import (
GELU,
GEGLU,
GLU,
QGELU,
QGEGLU,
ReLU,
......@@ -14,8 +15,6 @@ from .activation import (
SReLU,
SReGLU,
SiLU,
SwiGLU,
ClampedSwiGLU,
)
from .add_extra_input import AddExtraInput
from .all_gather import AllGather
......@@ -24,6 +23,7 @@ from .basic_linear import BasicLinear
from .bias import Bias
from .constant_scale import ConstantScale
from .dropout import Dropout
from .grouped_linear import GroupedLinear
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
......@@ -32,3 +32,4 @@ from .quantize import Quantize
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
from .rmsnorm import RMSNorm
from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU
......@@ -20,6 +20,7 @@ from .._common import maybe_dequantize
__all__ = [
"GELU",
"GEGLU",
"GLU",
"QGELU",
"QGEGLU",
"ReLU",
......@@ -27,8 +28,6 @@ __all__ = [
"SReLU",
"SReGLU",
"SiLU",
"SwiGLU",
"ClampedSwiGLU",
]
......@@ -164,6 +163,38 @@ class GELU(_ActivationOperation):
return tex.dgelu(*args, **kwargs)
class GLU(_ActivationOperation):
r"""Gated Linear Unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{GLU}(a,b) = \sigma(a) * b
where :math:`\sigma` is the sigmoid function.
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`__
and `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.glu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dglu(*args, **kwargs)
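The convention warning above is worth a concrete check; an editor's sketch comparing against torch.nn.functional.glu:

import torch

x = torch.randn(4, 8)
a, b = x.chunk(2, dim=-1)
te_style = torch.sigmoid(a) * b        # TE convention: gate the FIRST half
pt_style = torch.nn.functional.glu(x)  # PyTorch: a * sigmoid(b), gate the SECOND half
assert torch.allclose(pt_style, a * torch.sigmoid(b))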
class GEGLU(_ActivationOperation):
r"""Gaussian Error Gated Linear Unit
......@@ -355,76 +386,3 @@ class SiLU(_ActivationOperation):
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dsilu(*args, **kwargs)
class SwiGLU(_ActivationOperation):
r"""Swish gated linear unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{SwiGLU}(a,b) = \text{SiLU}(a) * b
where
.. math::
\text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
The Sigmoid Linear Unit (SiLU) gating function is also known as
the swish function. See
`GLU Variants Improve Transformer<https://arxiv.org/abs/2002.05202>`__
and `Gaussian Error Linear Units (GELUs)<https://arxiv.org/abs/1606.08415>`__.
"""
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.swiglu(*args, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dswiglu(*args, **kwargs)
class ClampedSwiGLU(_ActivationOperation):
r"""GPT-OSS
Implementation based on `GPT-OSS<https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.
This activation has two differences compared to the original SwiGLU
1. Both gate and pre-activations are clipped based on parameter limit.
2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation.
.. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations, which is different
from the GPT-OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor.
Parameters
----------
limit : float
The clamp limit.
alpha : float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input : bool, default = False
Quantize input tensor when caching for use in the backward pass.
"""
def __init__(
self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False
):
super().__init__(cache_quantized_input=cache_quantized_input)
self.limit = limit
self.alpha = alpha
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for grouped linear layer."""
from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
import contextlib
import math
from typing import Any, Optional
import torch
import transformer_engine_torch as tex
from ...cpp_extensions import general_grouped_gemm
from ...distributed import CudaRNGStatesTracker
from ...module.base import (
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
get_dummy_wgrad,
)
from ...quantization import FP8GlobalStateManager, Recipe
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
round_up_to_nearest_multiple,
)
from .._common import is_quantized_tensor, maybe_dequantize
from ..op import BasicOperation, OperationContext
class GroupedLinear(BasicOperation):
r"""Apply multiple linear transformations: :math:``y_i = x_i W_i^T + b_i``
This feature is experimental and subject to change.
This is equivalent to splitting the input tensor along its first
dimension, applying a separate ``torch.nn.Linear`` to each split,
and concatenating along the first dimension.
Parameters
----------
num_groups : int
Number of linear transformations.
in_features : int
Inner dimension of input tensor.
out_features : int
Inner dimension of output tensor.
bias : bool, default = ``True``
Apply additive bias.
device : torch.device, default = default CUDA device
Tensor device.
dtype : torch.dtype, default = default dtype
Tensor datatype.
rng_state_tracker_function : callable
Function that returns ``CudaRNGStatesTracker``, which is used
for model-parallel weight initialization.
accumulate_into_main_grad : bool, default = ``False``
Whether to directly accumulate weight gradients into the
weight's ``main_grad`` attribute instead of relying on PyTorch
autograd. The weight's ``main_grad`` must be set externally
and there is no guarantee that ``grad`` will be set or be
meaningful. This is primarily intended to integrate with
Megatron-LM. If the weight tensor also has the attribute
``overwrite_main_grad`` set to ``True``, ``main_grad`` will be
overwritten instead of accumulated into.
"""
# Operation expects input split sizes
num_extra_inputs: int = 1
def __init__(
self,
num_groups: int,
in_features: int,
out_features: int,
*,
bias: bool = True,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
accumulate_into_main_grad: bool = False,
) -> None:
super().__init__()
# Weight tensor dimensions
self.num_groups: int = num_groups
self.in_features: int = in_features
self.out_features: int = out_features
if self.num_groups <= 0:
raise ValueError(f"Invalid number of groups ({self.num_groups})")
if self.in_features <= 0:
raise ValueError(f"Invalid input size ({self.in_features})")
if self.out_features <= 0:
raise ValueError(f"Invalid output size ({self.out_features})")
# Weight tensor attributes
device = canonicalize_device(device)
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Initialize recipe state if needed for natively quantized weight
self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
if self._with_quantized_weight:
self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
# RNG state tracker
self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]]
self._rng_state_tracker_function = rng_state_tracker_function
# Register weights
self.weight0: torch.nn.Parameter
for group_idx in range(self.num_groups):
weight_tensor = torch.empty(
self.out_features,
self.in_features,
device="meta",
dtype=dtype,
)
self.register_parameter(
f"weight{group_idx}",
torch.nn.Parameter(weight_tensor),
)
# Register biases
self.bias0: Optional[torch.nn.Parameter]
for group_idx in range(self.num_groups):
bias_tensor = None
if bias:
bias_tensor = torch.empty(
self.out_features,
device="meta",
dtype=dtype,
)
bias_tensor = torch.nn.Parameter(bias_tensor)
self.register_parameter(f"bias{group_idx}", bias_tensor)
# Initialize weights if needed
if device.type != "meta":
self.reset_parameters()
# Whether to accumulate weight gradient into main_grad
self._accumulate_into_main_grad: bool = accumulate_into_main_grad
def num_quantizers(self, mode: str) -> int:
if mode == "forward":
return 2 * self.num_groups
if mode == "backward":
return self.num_groups
return 0
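    # Per-group quantizer layout: forward slot 2*i holds the input quantizer
    # and slot 2*i + 1 the weight quantizer for group i; backward slot i
    # holds the grad output quantizer (see the get_quantizer calls below).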
@property
def has_bias(self) -> bool:
"""Whether an additive bias is being applied"""
return self.bias0 is not None
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""
# Parameter device
device = self.weight0.device
if device.type == "meta":
device = canonicalize_device(None)
# Initialize weight values
# Note: Allocate a single buffer in order to support grouped
# GEMM kernels that expect a single weight buffer.
packed_weights = torch.empty(
self.num_groups,
self.out_features,
self.in_features,
dtype=self.weight0.dtype,
device=device,
)
weights = [packed_weights[idx] for idx in range(self.num_groups)]
for weight in weights:
init_context = contextlib.nullcontext()
if self._rng_state_tracker_function is not None:
init_context = self._rng_state_tracker_function().fork()
with init_context:
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
# Quantize weights if needed
if self._with_quantized_weight:
# Configure quantizers
quantizers = [
self.get_quantizer("forward", 2 * idx + 1) for idx in range(self.num_groups)
]
with_rowwise_usage = True
with_columnwise_usage = torch.is_grad_enabled()
for quantizer in quantizers:
if quantizer is None:
raise RuntimeError(
"Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. "
"This is most likely because the weight was initialized "
"within quantized_model_init, but the forward pass was not "
"performed within autocast."
)
quantizer.set_usage(
rowwise=with_rowwise_usage,
columnwise=with_columnwise_usage,
)
quantizer.internal = False
# Quantize weights
weights = self._quantize_weights(weights, quantizers)
# Register weights
for group_idx, weight in enumerate(weights):
if not isinstance(weight, torch.nn.Parameter):
weight = torch.nn.Parameter(weight)
setattr(self, f"weight{group_idx}", weight)
# Initialize biases if needed
if self.bias0 is not None:
packed_biases = torch.zeros(
self.num_groups,
self.out_features,
dtype=self.bias0.dtype,
device=device,
)
for group_idx in range(self.num_groups):
bias = torch.nn.Parameter(packed_biases[group_idx])
setattr(self, f"bias{group_idx}", bias)
def _quantize_weights(
self,
weights: Sequence[torch.Tensor],
quantizers: Sequence[Quantizer],
) -> Sequence[torch.Tensor]:
"""Construct quantized weight tensors."""
# Manually construct MXFP8 weights
if isinstance(quantizers[0], MXFP8Quantizer):
return self._quantize_weights_mxfp8(weights, quantizers)
# Use quantizers to construct quantized weights
with torch.no_grad():
return [quantizer(weight) for quantizer, weight in zip(quantizers, weights)]
def _quantize_weights_mxfp8(
self,
weights: Sequence[torch.Tensor],
quantizers: Sequence[Quantizer],
) -> Sequence[MXFP8Tensor]:
"""Construct MXFP8 weight tensors.
Instead of allocating separate buffers for each weight tensor,
this function constructs large buffers and assigns subviews to
each tensor. This is intended to support grouped GEMM kernels
that expect packed buffers.
"""
# Tensor dimensions
num_groups = len(weights)
out_features, in_features = weights[0].size()
packed_shape = (num_groups, out_features, in_features)
unpacked_shape = (out_features, in_features)
# Tensor attributes
device = weights[0].device
dtype = weights[0].dtype
requires_grad = torch.is_grad_enabled()
with_rowwise_usage = quantizers[0].rowwise_usage
with_columnwise_usage = quantizers[0].columnwise_usage
# Construct packed buffers
rowwise_data = [None] * num_groups
rowwise_scales = [None] * num_groups
columnwise_data = [None] * num_groups
columnwise_scales = [None] * num_groups
if with_rowwise_usage:
scale_shape = (
num_groups,
round_up_to_nearest_multiple(out_features, 128),
round_up_to_nearest_multiple(in_features // 32, 4),
)
packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device)
packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device)
rowwise_data = [packed_data[idx] for idx in range(num_groups)]
rowwise_scales = [packed_scales[idx] for idx in range(num_groups)]
if with_columnwise_usage:
scale_shape = (
num_groups,
round_up_to_nearest_multiple(out_features // 32, 4),
round_up_to_nearest_multiple(in_features, 128),
)
packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device)
packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device)
columnwise_data = [packed_data[idx] for idx in range(num_groups)]
columnwise_scales = [packed_scales[idx] for idx in range(num_groups)]
# Construct MXFP8 tensors and cast to MXFP8
out = []
with torch.no_grad():
for group_idx in range(num_groups):
weight = MXFP8Tensor(
shape=unpacked_shape,
dtype=dtype,
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise_data=rowwise_data[group_idx],
rowwise_scale_inv=rowwise_scales[group_idx],
columnwise_data=columnwise_data[group_idx],
columnwise_scale_inv=columnwise_scales[group_idx],
quantizer=quantizers[group_idx],
requires_grad=requires_grad,
with_gemm_swizzled_scales=False,
)
weight.copy_(weights[group_idx])
out.append(weight)
return out
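    # Shape illustration (hypothetical sizes): with num_groups=4,
    # out_features=256, and in_features=1024, the packed rowwise data buffer
    # is (4, 256, 1024) uint8 and the rowwise scale buffer is (4, 256, 32),
    # since MXFP8 stores one scale per 32-element block along the
    # contraction dimension, padded up to kernel-friendly multiples.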
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
# Initialize params if needed
if any(param.device.type == "meta" for param in self.parameters()):
self.reset_parameters()
# Check that weights are consistent
dtype = self.weight0.dtype
device = self.weight0.device
weight_requires_grad = self.weight0.requires_grad
weight_tensor_type = type(self.weight0.data)
for group_idx in range(self.num_groups):
weight = getattr(self, f"weight{group_idx}")
if weight.dtype != dtype:
raise RuntimeError(
f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})."
)
if not devices_match(weight.device, device):
raise RuntimeError(
f"Weight {group_idx} has invalid device "
f"(expected {device}, got {weight.device})."
)
if weight.requires_grad != weight_requires_grad:
raise RuntimeError(
f"Weight {group_idx} has requires_grad={weight.requires_grad}, "
f"but expected requires_grad={weight_requires_grad}."
)
if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck
raise RuntimeError(
f"Weight {group_idx} has invalid tensor type "
f"(expected {weight_tensor_type.__name__}, "
f"got {type(weight.data).__name__})."
)
# Check that biases are consistent
for group_idx in range(self.num_groups):
bias = getattr(self, f"bias{group_idx}")
if self.has_bias:
if bias is None:
raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized")
if bias.dtype != dtype:
raise RuntimeError(
f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})."
)
if not devices_match(bias.device, device):
raise RuntimeError(
f"Bias {group_idx} has invalid device "
f"(expected {device}, got {bias.device})."
)
if bias.requires_grad != weight_requires_grad:
raise RuntimeError(
f"Bias {group_idx} has requires_grad={bias.requires_grad}, "
f"but expected requires_grad={weight_requires_grad}."
)
else:
if bias is not None:
raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized")
def pre_fuser_forward(self, *, requires_grad: bool) -> None:
super().pre_fuser_forward(requires_grad=requires_grad)
if FP8GlobalStateManager.is_fp8_enabled():
# Assume weights have consistent grad requirement
weight_requires_grad = requires_grad and self.weight0.requires_grad
# Configure quantizer usages
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
for group_idx in range(self.num_groups):
input_quantizer = self.get_quantizer("forward", 2 * group_idx)
weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
grad_output_quantizer = self.get_quantizer("backward", group_idx)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)
for group_idx in range(self.num_groups):
# Input/grad output quantizers use internal tensors
input_quantizer = self.get_quantizer("forward", 2 * group_idx)
grad_output_quantizer = self.get_quantizer("backward", group_idx)
if input_quantizer is not None:
input_quantizer.internal = True
if grad_output_quantizer is not None:
grad_output_quantizer.internal = True
# Handle weight quantizer
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1)
if weight_quantizer is None:
pass
elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)):
# Make sure weight param has correct quantizer
weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
weight_quantizer.internal = False
getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy())
else:
# Use internal tensors if quantized weights will not be
# exposed externally
weight_quantizer.internal = (
not FP8GlobalStateManager.with_fp8_parameters()
and not getattr(self, "_with_quantized_weight", False)
)
# Recipe-specific configuration
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
if recipe is not None:
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
grad_output_quantizer.force_pow_2_scales = (
recipe.fp8_quant_bwd_grad.power_2_scale
)
grad_output_quantizer.amax_epsilon_scales = (
recipe.fp8_quant_bwd_grad.amax_epsilon
)
def op_forward(self, *args, **kwargs):
raise RuntimeError(
f"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_forward` instead of `op_forward`."
)
def op_backward(self, *args, **kwargs):
raise RuntimeError(
f"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_backward` instead of `op_backward`."
)
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
num_groups = self.num_groups
has_bias = self.has_bias
device = self.weight0.device
# Check which grads are required
ctx = basic_op_ctxs[0]
input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad
# Quantizers
input_quantizers = [None] * num_groups
weight_quantizers = [None] * num_groups
grad_output_quantizers = [None] * num_groups
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
for group_idx in range(num_groups):
input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx)
weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1)
grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx)
# Get autocast dtype if needed
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = self.weight0.dtype
# Extract split sizes from extra input
split_sizes = basic_op_extra_inputs[0][0]
split_sizes_int = [int(s) for s in split_sizes.tolist()]
if len(split_sizes_int) != num_groups:
raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.")
# Extract params
weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)]
bs = None
if has_bias:
bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)]
# Convert weight dtype if needed
ws = []
for w, quantizer in zip(weights, weight_quantizers):
if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not is_quantized_tensor(w):
quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = quantizer(w)
ws.append(w)
# Split input tensor and convert dtypes if needed
x = maybe_dequantize(input_, dtype)
xs = None
if with_quantized_compute:
for quantizer in input_quantizers:
quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
else:
xs = torch.split(x, split_sizes_int)
# Allocate output tensor
in_shape = list(input_.size())
out_shape = in_shape[:-1] + [self.out_features]
out = torch.empty(out_shape, dtype=dtype, device=device)
# Perform GEMMs
general_grouped_gemm(
ws,
xs,
[out],
[None] * num_groups, # quantization_params
dtype,
m_splits=split_sizes_int,
bias=bs,
use_bias=has_bias,
use_split_accumulator=_2X_ACC_FPROP,
single_output=True,
)
# Prepare weight tensors for backward pass
if not input_requires_grad:
ws = [None] * num_groups
elif with_quantized_compute:
for w, weight_param in zip(ws, weights):
if w is not weight_param:
w.update_usage(rowwise_usage=False, columnwise_usage=True)
# Prepare input tensor for backward pass
if not weight_requires_grad:
xs = [None] * num_groups
elif with_quantized_compute:
for x in xs:
x.update_usage(rowwise_usage=False, columnwise_usage=True)
# Save state for backward pass
if ctx.requires_grad:
ctx.save_for_backward(split_sizes, *xs, *ws)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizers = input_quantizers
ctx.weight_quantizers = weight_quantizers
ctx.grad_output_quantizers = grad_output_quantizers
ctx.grad_input_quantizers = None
ctx.dtype = dtype
ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad
return out, [()]
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
num_groups = self.num_groups
has_bias = self.has_bias
device = self.weight0.device
# Saved tensors from forward pass
ctx = basic_op_ctxs[0]
saved_tensors = ctx.saved_tensors
split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:]
xs, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:]
ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:]
# Split grad output tensor and convert dtypes if needed
split_sizes_int = [int(s) for s in split_sizes.tolist()]
dy = maybe_dequantize(grad_output, ctx.dtype)
dys = None
grad_biases = [None] * num_groups
if ctx.with_quantized_compute:
for quantizer in ctx.grad_output_quantizers:
quantizer.set_usage(
rowwise=ctx.input_requires_grad,
columnwise=ctx.weight_requires_grad,
)
dys = tex.split_quantize(dy, split_sizes_int, ctx.grad_output_quantizers)
if has_bias:
grad_biases = [
dy.reshape(-1, dy.size(-1)).sum(dim=0)
for dy in torch.split(grad_output, split_sizes_int)
]
else:
dys = torch.split(dy, split_sizes_int)
if has_bias:
grad_biases = [dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys]
# Initialize grad weight buffers
accumulate_into_main_grad = self._accumulate_into_main_grad
grad_weights = [None] * num_groups
if ctx.weight_requires_grad:
if accumulate_into_main_grad:
# Megatron-LM wgrad fusion
# Note: Get grad tensors from params so we can
# accumulate directly into it.
for group_idx in range(num_groups):
weight_param = getattr(self, f"weight{group_idx}")
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
grad_weights[group_idx] = weight_param.main_grad
accumulate_into_main_grad = not getattr(self.weight0, "overwrite_main_grad", False)
else:
weight_shape = ws[0].size()
for group_idx in range(num_groups):
grad_weights[group_idx] = torch.empty(
weight_shape,
dtype=ctx.dtype,
device=device,
)
else:
accumulate_into_main_grad = False
# Perform dgrad GEMMs
grad_input = None
if ctx.input_requires_grad:
out_shape = list(grad_output.size())
in_shape = out_shape[:-1] + [self.in_features]
grad_input = torch.empty(
in_shape,
dtype=ctx.dtype,
device=device,
)
general_grouped_gemm(
ws,
dys,
[grad_input],
[None] * num_groups, # quantization_params
ctx.dtype,
layout="NN",
m_splits=split_sizes_int,
use_split_accumulator=_2X_ACC_DGRAD,
single_output=True,
)
# Perform wgrad GEMMs
if ctx.weight_requires_grad:
general_grouped_gemm(
xs,
dys,
grad_weights,
[None] * num_groups, # quantization_params
ctx.dtype,
layout="NT",
m_splits=split_sizes_int,
use_split_accumulator=_2X_ACC_WGRAD,
accumulate=accumulate_into_main_grad,
)
# Clear input tensors if possible
clear_tensor_data(*xs)
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weights = [None] * num_groups
for group_idx in range(num_groups):
weight_param = getattr(self, f"weight{group_idx}")
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weights[group_idx] = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
grad_params = grad_weights + grad_biases if has_bias else grad_weights
return grad_input, [grad_params], [(None,)]
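# Reference semantics sketch (illustration only, names hypothetical):
# GroupedLinear behaves like per-group torch.nn.functional.linear calls on
# first-dimension splits of the input, relying on the torch import above.
def _grouped_linear_reference_sketch(x, split_sizes, weights, biases=None):
    outs = []
    for idx, x_i in enumerate(torch.split(x, split_sizes, dim=0)):
        b_i = None if biases is None else biases[idx]
        outs.append(torch.nn.functional.linear(x_i, weights[idx], b_i))
    return torch.cat(outs, dim=0)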
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for SwiGLU and variants."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU"]
class SwiGLU(BasicOperation):
r"""Swish gated linear unit
The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:
.. math::
\text{SwiGLU}(a,b) = \text{SiLU}(a) * b
where
.. math::
\text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
.. warning::
Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.
The Sigmoid Linear Unit (SiLU) gating function is also known as
the swish function. See
`GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.
Parameters
----------
cache_quantized_input : bool, default = False
Quantize input tensor when caching for use in the backward
pass. This will typically reduce memory usage but require
extra compute and increase numerical error. This feature is
highly experimental.
glu_interleave_size : int, optional
When set, the GLU activations will use a block interleaved
format. Instead of interpreting the input tensor as a
concatenation of gates and linear units (e.g.
:math:`[a_1, a_2, a_3, a_4, b_1, b_2, b_3, b_4]`
in the above notation), it will be interpreted
as alternating blocks of gates and linear units (e.g.
:math:`[a_1, a_2, b_1, b_2, a_3, a_4, b_3, b_4]`
when the interleave size is 2). This data format is highly
experimental and is primarily intended to support some advanced
fused kernels.
"""
def __init__(
self,
*,
cache_quantized_input: bool = False,
glu_interleave_size: Optional[int] = None,
):
super().__init__()
self.cache_quantized_input: bool = cache_quantized_input
self.glu_interleave_size: Optional[int] = glu_interleave_size
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Compute dtype
dtype: torch.dtype
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = input_.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise RuntimeError(f"Unsupported dtype ({dtype})")
# Check input tensor
input_ = maybe_dequantize(input_.contiguous(), dtype)
# Remove interleaving if needed
swiglu_in = input_
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Launch kernel
out = tex.swiglu(swiglu_in, next_op_input_quantizer)
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
input_quantizer = Float8CurrentScalingQuantizer(
tex.DType.kFloat8E4M3,
input_.device,
)
input_quantizer.set_usage(rowwise=True, columnwise=False)
input_ = input_quantizer(input_)
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
ctx.save_for_backward(input_)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
return out
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
(input_,) = ctx.saved_tensors
# Make sure tensors have correct dtypes
x = maybe_dequantize(input_.contiguous(), ctx.dtype)
dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype)
# Remove interleaving if needed
swiglu_in = x
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Quantizer for grad input
quantizer = ctx.prev_op_grad_output_quantizer
if self.glu_interleave_size is not None:
quantizer = None
# Launch kernel
grad_swiglu_in = tex.dswiglu(dy, swiglu_in, quantizer)
# Apply interleaving if needed
dx = grad_swiglu_in
if self.glu_interleave_size is not None:
shape = dx.size()
dx = dx.reshape(
-1,
2,
shape[-1] // (2 * self.glu_interleave_size),
self.glu_interleave_size,
)
dx = dx.transpose(1, 2).contiguous()
dx = dx.view(shape)
# Clear input tensor if possible
clear_tensor_data(input_)
return dx, ()
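# Sketch of the block-interleaved GLU layout handled above (illustration
# only). With glu_interleave_size=2 and last dimension 8, the interleaved
# layout [a1, a2, b1, b2, a3, a4, b3, b4] is rearranged into the
# concatenated layout [a1, a2, a3, a4, b1, b2, b3, b4] expected by the
# kernels; the backward pass swaps the two middle axes to invert it.
def _deinterleave_sketch(x: torch.Tensor, interleave_size: int) -> torch.Tensor:
    shape = x.size()
    x = x.reshape(-1, shape[-1] // (2 * interleave_size), 2, interleave_size)
    return x.transpose(1, 2).contiguous().view(shape)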
class ClampedSwiGLU(BasicOperation):
r"""GPT-OSS
Implementation based on ``GPT-OSS<https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>``__.
This activation has two differences compared to the original SwiGLU
1. Both gate and pre-activations are clipped based on parameter limit.
2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation.
.. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is different
from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor.
Parameters
----------
limit : float
The clamp limit.
alpha : float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input : bool, default = ``False``
Quantize input tensor when caching for use in the backward pass.
glu_interleave_size : int, optional
When set, the GLU activations will use an experimental block
interleaved format. See the corresponding option in the SwiGLU
operation for more details.
"""
def __init__(
self,
*,
limit: float = 7.0,
alpha: float = 1.702,
cache_quantized_input: bool = False,
glu_interleave_size: Optional[int] = None,
):
super().__init__()
self.limit: float = limit
self.alpha: float = alpha
self.cache_quantized_input: bool = cache_quantized_input
self.glu_interleave_size: Optional[int] = glu_interleave_size
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Compute dtype
dtype: torch.dtype
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
else:
dtype = input_.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise RuntimeError(f"Unsupported dtype ({dtype})")
# Check input tensor
x = maybe_dequantize(input_.contiguous(), dtype)
# Remove interleaving if needed
swiglu_in = x
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Launch kernel
out = tex.clamped_swiglu(
swiglu_in,
next_op_input_quantizer,
limit=self.limit,
alpha=self.alpha,
)
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device)
input_quantizer.set_usage(rowwise=True, columnwise=False)
x = input_quantizer(x)
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
return out
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
(input_,) = ctx.saved_tensors
# Make sure tensors have correct dtypes
x = maybe_dequantize(input_.contiguous(), ctx.dtype)
dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype)
# Remove interleaving if needed
swiglu_in = x
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Quantizer for grad input
quantizer = ctx.prev_op_grad_output_quantizer
if self.glu_interleave_size is not None:
quantizer = None
# Launch kernel
grad_swiglu_in = tex.clamped_dswiglu(
dy,
swiglu_in,
quantizer,
limit=self.limit,
alpha=self.alpha,
)
# Apply interleaving if needed
dx = grad_swiglu_in
if self.glu_interleave_size is not None:
shape = dx.size()
dx = dx.reshape(
-1,
2,
shape[-1] // (2 * self.glu_interleave_size),
self.glu_interleave_size,
)
dx = dx.transpose(1, 2).contiguous()
dx = dx.view(shape)
# Clear input tensor if possible
clear_tensor_data(input_)
return dx, ()
class ScaledSwiGLU(BasicOperation):
r"""SwiGLU with post-scaling.
If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is
multiplied with an extra input tensor of shape
``(d_1, ..., d_{n-1})``.
Parameters
----------
glu_interleave_size : int, optional
When set, the GLU activations will use an experimental block
interleaved format. See the corresponding option in the SwiGLU
operation for more details.
"""
# Operation expects scales
num_extra_inputs: int = 1
def __init__(self, glu_interleave_size: Optional[int] = None):
super().__init__()
self.glu_interleave_size: Optional[int] = glu_interleave_size
def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError(
f"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_forward` instead of `op_forward`."
)
def op_backward(self, *args, **kwargs) -> None:
raise RuntimeError(
f"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_backward` instead of `op_backward`."
)
def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
extra_input = basic_op_extra_inputs[0][0]
# Determine compute dtype
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
elif isinstance(input_, torch.Tensor):
dtype = input_.dtype
else:
dtype = extra_input.dtype
# Make sure inputs are in correct dtype
input_ = maybe_dequantize(input_, dtype)
scales = maybe_dequantize(extra_input, dtype)
# Remove gate interleaving if needed
swiglu_in = input_
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Compute scaled SwiGLU
swiglu_out = tex.swiglu(swiglu_in, None)
out = swiglu_out * scales.unsqueeze(-1)
# Save state for backward pass
ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
ctx.save_for_backward(
input_,
scales if ctx.input_requires_grad else None,
)
return out, [()]
def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
ctx = basic_op_ctxs[0]
input_, scales = ctx.saved_tensors
input_ = maybe_dequantize(input_, ctx.dtype)
if scales is not None:
scales = maybe_dequantize(scales, ctx.dtype)
grad_output = maybe_dequantize(grad_output, ctx.dtype)
# Remove gate interleaving if needed
swiglu_in = input_
if self.glu_interleave_size is not None:
shape = swiglu_in.size()
swiglu_in = swiglu_in.reshape(
-1,
shape[-1] // (2 * self.glu_interleave_size),
2,
self.glu_interleave_size,
)
swiglu_in = swiglu_in.transpose(1, 2).contiguous()
swiglu_in = swiglu_in.view(shape)
# Compute input grad
grad_input = None
if ctx.input_requires_grad:
grad_swiglu_out = grad_output * scales.unsqueeze(-1)
grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None)
grad_input = grad_swiglu_in
if self.glu_interleave_size is not None:
shape = grad_input.size()
grad_input = grad_input.reshape(
-1,
2,
shape[-1] // (2 * self.glu_interleave_size),
self.glu_interleave_size,
)
grad_input = grad_input.transpose(1, 2).contiguous()
grad_input = grad_input.view(shape)
# Compute scales grad by recomputing SwiGLU
grad_extra_input = None
if ctx.extra_input_requires_grad:
swiglu_out = tex.swiglu(swiglu_in, None)
grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output)
# Clear input tensor if possible
clear_tensor_data(ctx.saved_tensors[0]) # input_
return grad_input, [()], [(grad_extra_input,)]
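# Reference sketch of the post-scaling semantics (illustration only): the
# SwiGLU output of shape (d_1, ..., d_n) is multiplied by a scale tensor of
# shape (d_1, ..., d_{n-1}), broadcast over the last dimension.
def _scaled_swiglu_reference_sketch(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(a) * b * scales.unsqueeze(-1)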
......@@ -4,39 +4,27 @@
"""Compound tensor operation supported by the operation fuser."""
from .backward_activation_bias import (
BackwardActivationBias,
fuse_backward_activation_bias,
)
from .backward_add_rmsnorm import (
BackwardAddRMSNorm,
fuse_backward_add_rmsnorm,
)
from .backward_linear_add import (
BackwardLinearAdd,
fuse_backward_linear_add,
)
from .backward_linear_scale import (
BackwardLinearScale,
fuse_backward_linear_scale,
)
from .forward_linear_bias_activation import (
ForwardLinearBiasActivation,
fuse_forward_linear_bias_activation,
)
from .forward_linear_bias_add import (
ForwardLinearBiasAdd,
fuse_forward_linear_bias_add,
)
from .forward_linear_scale_add import (
ForwardLinearScaleAdd,
fuse_forward_linear_scale_add,
)
from .userbuffers_backward_linear import (
UserbuffersBackwardLinear,
fuse_userbuffers_backward_linear,
)
from .userbuffers_forward_linear import (
UserbuffersForwardLinear,
fuse_userbuffers_forward_linear,
)
from ..fuser import register_backward_fusion, register_forward_fusion
from .backward_activation_bias import BackwardActivationBias
from .backward_add_rmsnorm import BackwardAddRMSNorm
from .backward_linear_add import BackwardLinearAdd
from .backward_linear_scale import BackwardLinearScale
from .forward_linear_bias_activation import ForwardLinearBiasActivation
from .forward_linear_bias_add import ForwardLinearBiasAdd
from .forward_linear_scale_add import ForwardLinearScaleAdd
from .userbuffers_backward_linear import UserbuffersBackwardLinear
from .userbuffers_forward_linear import UserbuffersForwardLinear
# Register forward fusions
register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops)
register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops)
register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops)
register_forward_fusion(ForwardLinearScaleAdd.fuse_forward_ops)
# Register backward fusions
register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops)
register_backward_fusion(BackwardLinearAdd.fuse_backward_ops)
register_backward_fusion(BackwardLinearScale.fuse_backward_ops)
register_backward_fusion(BackwardActivationBias.fuse_backward_ops)
register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops)
......@@ -53,8 +53,8 @@ class BackwardActivationBias(FusedOperation):
]:
# Get basic operation contexts
activation_op_ctx = basic_op_ctxs[0]
bias_op_ctx = basic_op_ctxs[1]
bias_op_ctx = basic_op_ctxs[0]
activation_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass
(act_input,) = activation_op_ctx.saved_tensors
......@@ -79,68 +79,59 @@ class BackwardActivationBias(FusedOperation):
# Clear activation input tensor
clear_tensor_data(act_input)
return dx, [(), (db,)], [(), ()]
return dx, [(db,), ()], [(), ()]
def fuse_backward_activation_bias(
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dact + dbias + quantize
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
recipe : Recipe, optional
Used quantization recipe
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
# Check if recipe supports bias activation fusion
if recipe is None:
return ops
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 3:
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
*,
recipe: Optional[Recipe] = None,
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
recipe : Recipe, optional
Quantization recipe.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
"""
# Check if recipe supports bias activation fusion
if recipe is None:
return ops
# Scan through ops, fusing if possible
out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:
if (
isinstance(window[2], _fusible_activations)
and isinstance(window[1], Bias)
and window[0].get_grad_output_quantizer() is not None
):
# Construct fused op if window matches pattern
op = BackwardActivationBias(bias=window[1], activation=window[2])
window = [window[0], op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-2])
window = window[-2:]
# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is a supported activation
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, _fusible_activations):
continue
# Check if second op is bias
op, _ = ops[0]
if not isinstance(op, Bias):
continue
# Check if third op has a grad input quantizer
op, _ = ops[1]
if not op.num_quantizers("backward") > 0:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardActivationBias(
activation=window[0][0],
bias=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
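# Simplified sketch of the window-scan fusion pattern used above (names are
# hypothetical): walk the op list with a three-op window, replace a matching
# triple with its fused op, and otherwise shift the window forward by one.
def _fuse_scan_sketch(ops, matches, make_fused):
    out, window = [], []
    for op in ops:
        window.append(op)
        if len(window) == 3:
            if matches(window):
                window = [make_fused(window)]
            else:
                out.append(window.pop(0))
    out.extend(window)
    return out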