Commit c1a1c04e authored by wenjh

Merge nv_main(2.10) to main


Signed-off-by: wenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
......@@ -16,7 +16,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_experimental
from transformer_engine.pytorch.tensor.utils import is_custom
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
......@@ -57,7 +57,7 @@ from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore
from ..tensor.quantized_tensor import (
from ..quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Quantizer,
......@@ -67,10 +67,15 @@ from ..tensor.quantized_tensor import (
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
general_gemm,
......@@ -167,6 +172,9 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
if is_cpu_offload_enabled():
start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group)
weight_requires_grad = weight.requires_grad
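A condensed sketch of the CPU-offload hook pattern introduced across the modules in this merge; the helper names are taken from the diff itself, while the one-line descriptions are only a reading of how they are used here, not documented API semantics:

from transformer_engine.pytorch.cpu_offload import (
    is_cpu_offload_enabled,
    start_offload,            # assumed: begin tracking tensors produced by this layer
    mark_activation_offload,  # assumed: tag saved activations as candidates for host offload
    mark_not_offload,         # assumed: exclude weights/biases from offloading
)

def forward_offload_pattern(inputmat, weight, bias, saved_activations):
    # Mirrors the call sites added in this diff; control flow heavily simplified.
    if is_cpu_offload_enabled():
        start_offload(inputmat)
    # ... normalization and GEMMs happen here ...
    if is_cpu_offload_enabled():
        mark_activation_offload(*saved_activations)  # e.g. inputmat, mu, rsigma, ln_out
        mark_not_offload(weight, bias)               # parameters stay resident on device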
......@@ -203,13 +211,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
experimental = is_experimental(input_quantizer)
custom = is_custom(input_quantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not experimental # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
)
# Apply normalization
......@@ -255,8 +263,8 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
# experimental recipe doesn't need to support quantized AG
if not with_quantized_norm and not experimental:
# custom recipe doesn't need to support quantized AG
if not with_quantized_norm and not custom:
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
......@@ -285,12 +293,15 @@ class _LayerNormLinear(torch.autograd.Function):
# Prepare weight tensor
# ------------------------------------------------------
weightmat = weight
quantized_weight = False
is_weight_param_quantized = False
if fp8 or debug:
quantized_weight = not isinstance(weight, QuantizedTensorStorage)
is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage)
# Configure quantizer
if weight_quantizer is not None:
# If weight is already quantized, no need to set quantizer states
if is_weight_param_quantized:
weight_quantizer = weight._quantizer
elif weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
# Get quantized weight
......@@ -422,10 +433,6 @@ class _LayerNormLinear(torch.autograd.Function):
):
ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading:
mark_activation_offload(inputmat, mu, rsigma, ln_out)
......@@ -438,42 +445,21 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group,
mu,
rsigma,
weightmat if quantized_weight else None,
weightmat if fp8 and not is_weight_param_quantized else None,
ln_out if weight.requires_grad else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
# Do not offload weights and biases
weight.offloading_activation = False
weightmat.offloading_activation = False
if bias is not None:
bias.offloading_activation = False
ln_weight.offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
if cpu_offloading:
mark_not_offload(
weightmat,
weight,
bias,
ln_weight,
ln_bias,
)
if (
fine_grained_activation_offloading
and weight.requires_grad
and fuse_wgrad_accumulation
):
if hasattr(weight, "grad_added_to_main_grad"):
ctx.has_grad_added_to_main_grad = True
ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
weight.grad_added_to_main_grad = True
ctx.weight_object = weight
else:
ctx.has_grad_added_to_main_grad = False
if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.has_grad_added_to_main_grad:
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user
......@@ -495,7 +481,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.requires_dgrad = inp_requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.quantized_weight = quantized_weight
ctx.is_weight_param_quantized = is_weight_param_quantized
if fuse_wgrad_accumulation and weight.requires_grad:
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
......@@ -579,6 +565,7 @@ class _LayerNormLinear(torch.autograd.Function):
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None
......@@ -599,7 +586,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fsdp_shapes,
mu,
rsigma,
weight if ctx.fp8 and ctx.quantized_weight else None,
weight if ctx.fp8 and not ctx.is_weight_param_quantized else None,
ln_out,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
......
......@@ -18,7 +18,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.tensor.utils import is_experimental
from transformer_engine.pytorch.tensor.utils import is_custom
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
......@@ -70,8 +70,13 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
......@@ -106,6 +111,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
}
if recipe.delayed() or recipe.mxfp8():
# Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
......@@ -121,6 +127,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, tex.dbias_dsilu),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
}
# no activation fusion written yet
# Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
......@@ -142,6 +149,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
}
raise NotImplementedError(f"Unhandled recipe type {recipe}")
......@@ -206,6 +214,7 @@ class _LayerNormMLP(torch.autograd.Function):
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
activation_params: Optional[dict],
normalization: str,
ub_overlap_ag: bool,
ub_overlap_rs: bool,
......@@ -238,6 +247,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if is_cpu_offload_enabled():
start_offload(inputmat)
tp_world_size = get_distributed_world_size(tp_group)
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
......@@ -275,13 +286,13 @@ class _LayerNormMLP(torch.autograd.Function):
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
experimental = is_experimental(fc1_input_quantizer)
custom = is_custom(fc1_input_quantizer)
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not experimental
and not custom
)
# Apply normalization
......@@ -321,8 +332,8 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
# experimental recipe doesn't need to support quantized AG
if not with_quantized_norm and not experimental:
# custom recipe doesn't need to support quantized AG
if not with_quantized_norm and not custom:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
......@@ -354,8 +365,17 @@ class _LayerNormMLP(torch.autograd.Function):
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
# No need to set the quantizer states if weights are already quantized
if isinstance(fc1_weight, QuantizedTensorStorage):
fc1_weight_quantizer = fc1_weight._quantizer
elif fc1_weight_quantizer is not None:
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
if isinstance(fc2_weight, QuantizedTensorStorage):
fc2_weight_quantizer = fc2_weight._quantizer
elif fc2_weight_quantizer is not None:
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
fc1_weight_final = module.get_weight_workspace(
tensor=fc1_weight,
quantizer=fc1_weight_quantizer,
......@@ -447,6 +467,7 @@ class _LayerNormMLP(torch.autograd.Function):
# ACTIVATION - sometimes activation is fused with the GEMM above.
fc1_out_without_bias = None
act_params = activation_params or {}
if bias_gelu_fusion:
fc1_out = None
......@@ -456,7 +477,7 @@ class _LayerNormMLP(torch.autograd.Function):
act_out, _, fc1_out, _ = fc1_outputs
elif debug:
fc1_out, *_ = fc1_outputs
act_out = activation_func(fc1_out, None)
act_out = activation_func(fc1_out, None, **act_params)
act_out = fc2_input_quantizer(act_out)
else:
fc1_out, *_ = fc1_outputs
......@@ -464,19 +485,19 @@ class _LayerNormMLP(torch.autograd.Function):
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_block_scaling():
# tex.quantize does not support GELU fusion for blockwise
act_out = activation_func(fc1_out, None)
act_out = activation_func(fc1_out, None, **act_params)
act_out = tex.quantize(act_out, fc2_input_quantizer)
elif recipe.custom():
# tex.quantize does not support custom quantizers
act_out = activation_func(fc1_out, None)
act_out = activation_func(fc1_out, None, **act_params)
act_out = fc2_input_quantizer(act_out)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)
else:
if fp8_calibration:
act_out = activation_func(fc1_out, None)
act_out = activation_func(fc1_out, None, **act_params)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)
if not is_grad_enabled:
clear_tensor_data(fc1_out)
......@@ -540,13 +561,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Cache state for backward pass
if is_grad_enabled:
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(fc1_weight_final, QuantizedTensorStorage):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensorStorage):
fc2_weight_final.update_usage(columnwise_usage=True)
if cpu_offloading:
mark_activation_offload(
inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
......@@ -577,6 +591,18 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data(act_out)
act_out = None
if cpu_offloading:
mark_not_offload(
ln_weight,
ln_bias,
fc1_weight_final,
fc1_weight,
fc1_bias,
fc2_weight_final,
fc2_weight,
fc2_bias,
)
tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
ln_weight,
......@@ -631,6 +657,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.device = device
ctx.activation_dtype = activation_dtype
ctx.activation = activation
ctx.activation_params = activation_params
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
......@@ -1017,6 +1044,7 @@ class _LayerNormMLP(torch.autograd.Function):
# --------------------------------------------------
# bias computation
act_params = ctx.activation_params or {}
fc1_bias_grad = None
fuse_gemm_and_bias_fc1_wgrad = False
if ctx.fc1_grad_output_quantizer is not None:
......@@ -1030,7 +1058,7 @@ class _LayerNormMLP(torch.autograd.Function):
dact = ctx.fc1_grad_output_quantizer(dact)
elif ctx.debug:
dact_func = _act_func(ctx.activation)[1]
dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None)
dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params)
fc1_bias_grad = dact.sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
elif (
......@@ -1042,7 +1070,10 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
)[2]
fc1_bias_grad, dact = dbias_dact_quantize_func(
fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer
fc2_dgrad,
fc1_out.to(ctx.activation_dtype),
ctx.fc1_grad_output_quantizer,
**act_params,
) # quantize bgrad gelu fused
else:
# Fusion: gemm + gelu,
......@@ -1051,7 +1082,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
)[1]
dact = activation_func_bwd(
fc2_dgrad, fc1_out.to(ctx.activation_dtype), None
fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params
) # activation in high precision
if ctx.fp8:
......@@ -1429,6 +1460,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # activation
None, # activation_params
None, # normalization
None, # ub_overlap_ag
None, # ub_overlap_rs
......@@ -1464,7 +1496,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
activation : str, default = 'gelu'
activation function used.
Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
'silu', and 'swiglu'.
'silu', 'swiglu', and 'clamped_swiglu'.
activation_params : dict, default = `None`
Additional parameters for the activation function.
At the moment, only used for 'clamped_swiglu' activation which
supports 'limit' and 'alpha' parameters.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
......@@ -1565,6 +1601,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
bias: bool = True,
normalization: str = "LayerNorm",
activation: str = "gelu",
activation_params: Optional[dict] = None,
output_layer_init_method: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False,
params_dtype: Optional[torch.dtype] = None,
......@@ -1592,6 +1629,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
self.use_bias = bias
self.activation = activation
self.activation_params = activation_params
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output
......@@ -1671,7 +1709,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_bias = None
# FC1 init
if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]:
if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]:
fc1_output_features = 2 * self.size_per_partition
else:
fc1_output_features = self.size_per_partition
......@@ -1926,6 +1964,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
......@@ -2055,6 +2094,19 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias)
fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32
act_params = self.activation_params or {}
# Default params for clamped_swiglu in Transformer Engine
clamped_swiglu_limit, clamped_swiglu_alpha = act_params.get("limit", 7.0), act_params.get(
"alpha", 1.702
)
def _clamped_swiglu(x, limit, alpha):
x_glu, x_linear = x.chunk(2, dim=-1)
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y = out_glu * (x_linear + 1)
return y
activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
......@@ -2069,6 +2121,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
* x.chunk(2, -1)[1],
"silu": torch.nn.functional.silu,
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"clamped_swiglu": lambda x: _clamped_swiglu(
x, clamped_swiglu_limit, clamped_swiglu_alpha
),
}
if self.activation not in activation_map:
raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
......@@ -2240,7 +2295,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
if not self.need_backward_dw():
return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
......
......@@ -59,7 +59,7 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.quantized_tensor import (
from ..quantized_tensor import (
QuantizedTensor,
QuantizedTensorStorage,
Quantizer,
......@@ -68,9 +68,14 @@ from ..tensor.quantized_tensor import (
)
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.utils import is_experimental
from ..tensor.utils import is_custom
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpu_offload import (
is_cpu_offload_enabled,
start_offload,
mark_not_offload,
mark_activation_offload,
)
from ...debug.pytorch.debug_state import TEDebugState
__all__ = ["Linear"]
......@@ -156,8 +161,8 @@ class _Linear(torch.autograd.Function):
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG
# experimental recipe check
experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer)
# custom recipe check
custom = is_custom(input_quantizer) or is_custom(weight_quantizer)
# ------------------------------------------------------
# Prepare input tensor
......@@ -181,7 +186,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorStorage) and not experimental:
if not isinstance(inputmat, QuantizedTensorStorage) and not custom:
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
......@@ -232,6 +237,9 @@ class _Linear(torch.autograd.Function):
else:
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
inputmat_total = inputmat
if is_cpu_offload_enabled():
start_offload(inputmat)
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# ------------------------------------------------------
# Input tensor is ready for GEMM...
......@@ -243,7 +251,8 @@ class _Linear(torch.autograd.Function):
weightmat = weight
if fp8 or debug:
# Configure quantizer
if weight_quantizer is not None:
# No need to set the quantizer states if weight is already quantized
if weight_quantizer is not None and not isinstance(weight, QuantizedTensor):
columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage:
columnwise_usage = (
......@@ -251,7 +260,9 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase()
)
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weight, QuantizedTensor):
# If weight is already quantized, no need to set quantizer states
weight_quantizer = weight._quantizer
# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
......@@ -392,11 +403,6 @@ class _Linear(torch.autograd.Function):
if backward_needs_input:
saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading and saved_inputmat is not None:
mark_activation_offload(saved_inputmat)
......@@ -442,12 +448,7 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
ctx.weight_object = weight
# Do not offload weights and biases
weight.offloading_activation = False
weightmat.offloading_activation = False
if bias is not None:
bias.offloading_activation = False
mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat,
......@@ -477,7 +478,7 @@ class _Linear(torch.autograd.Function):
ctx.main_grad_func = lambda: weight.main_grad
ctx.debug = debug
ctx.experimental = experimental
ctx.custom = custom
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = bias is not None
......@@ -647,7 +648,7 @@ class _Linear(torch.autograd.Function):
if isinstance(inputmat, QuantizedTensorStorage):
# Input tensor is already quantized
pass
elif ctx.debug or ctx.experimental:
elif ctx.debug or ctx.custom:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else:
......
......@@ -13,7 +13,7 @@ from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..quantization import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..tensor.quantized_tensor import QuantizedTensorStorage
from ..quantized_tensor import QuantizedTensorStorage
from ..utils import canonicalize_dtype
......
......@@ -21,7 +21,7 @@ from ...module.base import (
get_ub,
get_workspace,
)
from ...tensor.quantized_tensor import Quantizer
from ...quantized_tensor import Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
......
......@@ -21,7 +21,7 @@ from ...module.base import (
get_workspace,
_2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import Quantizer
from ...quantized_tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_dequantize, is_quantized_tensor
......
......@@ -28,7 +28,7 @@ from transformer_engine.pytorch.ops.fused import (
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
from transformer_engine.pytorch.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
......
......@@ -372,9 +372,9 @@ class FusedAdam(torch.optim.Optimizer):
"""
dtype = self.name_to_dtype_map[state_name]
if store_param_remainders:
data = torch.zeros_like(param, dtype=torch.int16)
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device)
else:
data = torch.empty_like(param, dtype=dtype)
data = torch.empty(param.shape, dtype=dtype, device=param.device)
if zero_buffer:
data.zero_()
......
......@@ -10,7 +10,7 @@ import torch
import transformer_engine_torch as tex
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
......
......@@ -2,18 +2,29 @@
#
# See LICENSE for license information.
"""Tensor with quantized data"""
"""Pure Python base classes for quantization."""
from __future__ import annotations
from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union
from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import copy
import warnings
import math
import torch
from torch.utils._pytree import tree_map
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor._quantization_helpers import (
_QuantizeFunc,
_IdentityFunc,
_stride_from_shape,
)
_quantized_tensor_cpu_supported_ops = (
torch.ops.aten.empty_like.default,
torch.ops.aten.copy_.default,
)
class QuantizedTensorStorage:
......@@ -30,7 +41,7 @@ class QuantizedTensorStorage:
XTensorStorage should contain all data members needed to
implement the functionality of the tensor, while
XTensor should only implement the functionality needed
to behave like regular torch.Tensor (liek __torch_dispatch__)."""
to behave like regular torch.Tensor (like __torch_dispatch__)."""
_quantizer: Optional[Quantizer]
......@@ -58,6 +69,12 @@ class QuantizedTensorStorage:
f"{self.__class__.__name__} class does not implement update_usage function"
)
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement get_usages function"
)
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
......@@ -123,6 +140,7 @@ def prepare_for_saving(
t, t_obj = tensor.prepare_for_saving()
tensor_list.extend(t)
tensor_objects_list.append(t_obj)
return tensor_list, tensor_objects_list
......@@ -309,72 +327,12 @@ class Quantizer(abc.ABC):
"""Returns whether or not given tensor can be quantized"""
return True
class _QuantizeFunc(torch.autograd.Function):
"""Quantize tensor"""
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: torch.Tensor,
quantize_impl: Callable,
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
return quantize_impl(tensor)
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return grad, None
class _IdentityFunc(torch.autograd.Function):
"""Identity function
If constructor keyword-arguments are provided, then construct a
new Float8Tensor using the provided tensor's attributes.
"""
@staticmethod
def forward(
ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
# Return input tensor if constructor kwargs are not provided
if init_kwargs is None:
return tensor.detach()
# Construct new tensor if constructor kwargs are provided
ctx.input_dtype = tensor.dtype
kwargs = tensor.get_metadata()
for key, val in init_kwargs.items():
kwargs[key] = val
return type(tensor)(tensor.shape, tensor.dtype, **kwargs)
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
grad_input = grad_output
if grad_input.dtype == ctx.input_dtype:
grad_input = grad_input.detach()
else:
grad_input = grad_input.to(ctx.input_dtype)
return grad_input, None
def _stride_from_shape(shape: list[int]):
if len(shape) == 0:
return []
rstride = [1]
for d in reversed(shape[1:]):
rstride.append(rstride[-1] * d)
return list(reversed(rstride))
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the quantizer"""
return {
"rowwise": self.rowwise_usage,
"columnwise": self.columnwise_usage,
}
class QuantizedTensor(torch.Tensor):
......@@ -387,7 +345,14 @@ class QuantizedTensor(torch.Tensor):
"""
def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False):
def __new__(
cls,
shape: Iterable[int],
dtype: torch.dtype,
*,
requires_grad: bool = False,
device: Optional[torch.device] = None,
):
# We are assuming only contiguous tensors
stride = _stride_from_shape(shape)
instance = torch.Tensor._make_wrapper_subclass(
......@@ -398,7 +363,7 @@ class QuantizedTensor(torch.Tensor):
dtype=dtype,
layout=torch.strided,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
device=torch.cuda.current_device() if device is None else device,
)
return instance
......@@ -428,6 +393,9 @@ class QuantizedTensor(torch.Tensor):
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement clear function"
)
def __repr__(self, *, tensor_contents=None) -> str:
return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
......@@ -469,6 +437,26 @@ class QuantizedTensor(torch.Tensor):
if func == torch.ops.aten.copy_.default:
dst = args[0]
src = args[1]
if (
isinstance(dst, QuantizedTensor)
and isinstance(src, QuantizedTensor)
and type(dst._quantizer) is type(src._quantizer)
and set(src.get_usages().keys()) == set(dst.get_usages().keys())
and all(
src.get_usages()[usage] == dst.get_usages()[usage]
for usage in src.get_usages().keys()
)
):
dst_tensors, dst_tensor_obj = dst.prepare_for_saving()
src_tensors, src_tensor_obj = src.prepare_for_saving()
for dst_tensor, src_tensor in zip(dst_tensors, src_tensors):
if dst_tensor is not None:
dst_tensor.copy_(src_tensor, *args[2:], **kwargs)
dst_tensor_obj.restore_from_saved(dst_tensors)
src_tensor_obj.restore_from_saved(src_tensors)
return None
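# The guard above is equivalent to requiring
#     type(dst._quantizer) is type(src._quantizer) and dst.get_usages() == src.get_usages()
# i.e. both tensors are the same kind of quantized tensor and hold exactly the same set of
# storages (rowwise/columnwise), so the underlying data tensors can be copied pairwise
# without a dequantize/quantize round trip; otherwise copy_ falls through to the paths below.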
if isinstance(dst, QuantizedTensor):
dst.quantize_(src)
else:
......@@ -481,6 +469,36 @@ class QuantizedTensor(torch.Tensor):
if func == torch.ops.aten.view.default:
raise NotImplementedError("{cls.__name__} class does not support tensor views")
# Empty like op
if func == torch.ops.aten.empty_like.default:
tensor = args[0]
device = kwargs.get("device", tensor.device)
requires_grad = kwargs.get("requires_grad", tensor.requires_grad)
pin_memory = kwargs.get("pin_memory", False)
usage = tensor.get_usages()
quantizer_usage = tensor._quantizer.get_usages()
tensor._quantizer.set_usage(**usage)
out = tensor._quantizer.make_empty(
shape=tensor.shape,
dtype=tensor.dtype,
device=device,
requires_grad=requires_grad,
pin_memory=pin_memory,
)
tensor._quantizer.set_usage(**quantizer_usage)
return out
if func == torch.ops.aten.numel.default:
tensor = args[0]
return math.prod(tensor.size())
if func == torch.ops.aten.is_pinned.default:
tensor = args[0]
for t in tensor.get_data_tensors():
if t is not None:
return func(t)
return False # Or error out?
def maybe_unwrap(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize(dtype=arg.dtype)
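The empty_like / numel / is_pinned handlers above are what let host-offloading code treat a QuantizedTensor like a plain tensor. A hedged sketch of the kind of call they enable (whether a given recipe supports CPU storage end to end is beyond what this diff shows):

import torch

def offload_to_host(q_tensor):
    # torch.empty_like routes through the quantizer's make_empty (with the new
    # pin_memory flag), preserving the source tensor's rowwise/columnwise usages.
    host_buf = torch.empty_like(q_tensor, device="cpu", pin_memory=True)
    # copy_ between two quantized tensors with matching quantizer type and usages
    # copies the underlying storages directly (see the copy_ handler above).
    host_buf.copy_(q_tensor, non_blocking=True)
    return host_buf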
......@@ -495,6 +513,10 @@ class QuantizedTensor(torch.Tensor):
and schema_arg.alias_info.is_write
):
arg.quantize_(new_arg)
elif isinstance(arg, list) and isinstance(new_arg, list):
# Recursively handle update for lists of tensors
for a, na in zip(arg, new_arg):
maybe_update_inplace(a, na, schema_arg)
# In-place op: dequantize, perform op, and quantize
if func._schema.is_mutable:
......@@ -521,6 +543,16 @@ class QuantizedTensor(torch.Tensor):
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
def check_if_cpu(arg):
if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu":
assert (
func in _quantized_tensor_cpu_supported_ops
), f"QuantizedTensor on CPU does not support this operation: {func}"
return arg
args = tree_map(check_if_cpu, args)
# Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
......@@ -551,20 +583,16 @@ class QuantizedTensor(torch.Tensor):
shape: Optional[Iterable[int]] = None,
dtype: Optional[torch.dtype] = None,
requires_grad: bool = False,
data: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Create new quantized tensor
By default, new tensor has the same attributes and underlying
data.
data. This function is intended to create views of tensors.
"""
if shape is None:
shape = data.shape if data is not None else tensor.shape
shape = shape if shape is not None else tensor.shape
dtype = dtype if dtype is not None else tensor.dtype
kwargs = tensor.get_metadata()
if data is not None:
kwargs["data"] = data
return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
......
......@@ -145,15 +145,25 @@ if __name__ == "__main__":
)
]
# Setup version and requirements.
# Having the framework extension depend on the core lib allows
# us to detect the CUDA version dynamically during compilation and
# choose the correct wheel for the TE core lib.
__version__ = te_version()
cuda_major_version = parse(torch.version.cuda).major
assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}."
te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}"
install_requires = install_requirements() + [te_core]
# Configure package
setuptools.setup(
name=PACKAGE_NAME,
version=te_version(),
version=__version__,
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(),
install_requires=install_requires,
tests_require=test_requirements(),
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
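A small worked example of the dependency pin produced above; the version numbers are illustrative:

from packaging.version import parse

# With torch.version.cuda == "12.4" and te_version() == "2.10.0" (both illustrative):
assert parse("12.4").major == 12
# te_core == "transformer_engine_cu12==2.10.0", appended to install_requires.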
......
......@@ -6,7 +6,7 @@
import torch
from .quantized_tensor import (
from ..quantized_tensor import (
QuantizedTensorStorage,
QuantizedTensor,
Quantizer,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Private helper functions and classes for quantized tensor implementations.
This module contains internal autograd functions and utilities that support
the quantization machinery.
"""
from __future__ import annotations
from typing import Callable, Optional, Tuple, Any, Dict, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
class _QuantizeFunc(torch.autograd.Function):
"""Quantize tensor"""
@staticmethod
def forward(
_ctx: Optional[torch.autograd.function.FunctionCtx], # unused
tensor: torch.Tensor,
quantize_impl: Callable,
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
return quantize_impl(tensor)
@staticmethod
def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return grad, None
class _IdentityFunc(torch.autograd.Function):
"""Identity function
If constructor keyword-arguments are provided, then construct a
new Float8Tensor using the provided tensor's attributes.
"""
@staticmethod
def forward(
ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
# Return input tensor if constructor kwargs are not provided
if init_kwargs is None:
return tensor.detach()
# Construct new tensor if constructor kwargs are provided
ctx.input_dtype = tensor.dtype
kwargs = tensor.get_metadata()
for key, val in init_kwargs.items():
kwargs[key] = val
return type(tensor)(tensor.shape, tensor.dtype, **kwargs)
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
grad_input = grad_output
if grad_input.dtype == ctx.input_dtype:
grad_input = grad_input.detach()
else:
grad_input = grad_input.to(ctx.input_dtype)
return grad_input, None
def _stride_from_shape(shape: list[int]):
"""Calculate stride from shape for contiguous tensors"""
if len(shape) == 0:
return []
rstride = [1]
for d in reversed(shape[1:]):
rstride.append(rstride[-1] * d)
return list(reversed(rstride))
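A quick worked example of the helper above:

# Row-major (contiguous) strides: the running product is built from the innermost
# dimension outwards and then reversed.
assert _stride_from_shape([2, 3, 4]) == [12, 4, 1]
assert _stride_from_shape([]) == []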
......@@ -15,11 +15,8 @@ from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
......@@ -220,6 +217,7 @@ class Float8BlockQuantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""
if device is None:
......@@ -235,12 +233,13 @@ class Float8BlockQuantizer(Quantizer):
data = None
scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(
scale_shape,
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
)
# Allocate FP8 data transpose if needed
......@@ -248,13 +247,17 @@ class Float8BlockQuantizer(Quantizer):
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
self.get_columnwise_shape(shape), dtype=torch.uint8, device=device
self.get_columnwise_shape(shape),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape,
dtype=torch.float32,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
......
......@@ -4,21 +4,18 @@
"""Tensor class with FP8 data"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
from typing import Any, Optional, Tuple, Iterable, Union
import warnings
import torch
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from ..utils import canonicalize_process_group, devices_match
from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
from ..constants import dist_group_type
from transformer_engine.pytorch.fp8 import int8_simulation_fp8_tensorwise
......@@ -105,6 +102,7 @@ class Float8Quantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8Tensor:
# Canonicalize tensor attributes
......@@ -112,16 +110,19 @@ class Float8Quantizer(Quantizer):
device = torch.device("cuda")
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
transpose_shape = [shape[-1]] + list(shape[:-1])
data_transpose = torch.empty(
transpose_shape,
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
......@@ -129,7 +130,7 @@ class Float8Quantizer(Quantizer):
shape=shape,
dtype=dtype,
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
fp8_dtype=self.dtype,
requires_grad=requires_grad,
data_transpose=data_transpose,
......@@ -291,6 +292,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8Tensor:
# Canonicalize tensor attributes
......@@ -298,25 +300,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
device = torch.device("cuda")
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
data_transpose = None
if self.columnwise_usage:
inner_dim = data.size(-1)
transpose_shape = [shape[-1]] + list(shape[:-1])
data_transpose = torch.empty(
inner_dim,
data.numel() // inner_dim,
transpose_shape,
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
return Float8Tensor(
shape=shape,
dtype=dtype,
data=data,
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
fp8_dtype=self.dtype,
requires_grad=requires_grad,
data_transpose=data_transpose,
......@@ -538,9 +541,36 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
self._transpose = None
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
def make_like(
cls,
tensor: QuantizedTensor,
*,
shape: Optional[Iterable[int]] = None,
dtype: Optional[torch.dtype] = None,
requires_grad: bool = False,
data: Optional[torch.Tensor] = None,
data_transpose: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Create new quantized tensor
By default, new tensor has the same attributes and underlying
data.
# View op
"""
if shape is None and data is not None:
shape = data.shape
new_tensor = super().make_like(
tensor, shape=shape, dtype=dtype, requires_grad=requires_grad
)
if data is not None:
new_tensor._data = data
if data_transpose is not None:
new_tensor._transpose = data_transpose
new_tensor._transpose_invalid = False
return new_tensor
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func == aten.view.default:
tensor = args[0]
data = tensor._data
......@@ -559,6 +589,9 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
or out_transpose_shape[1:] != out_shape[:-1]
):
out_transpose = None
else:
view_shape_for_transpose = [out_shape[-1]] + list(out_shape[:-1])
out_transpose = out_transpose.view(*view_shape_for_transpose)
return Float8Tensor(
shape=out_shape,
dtype=tensor.dtype,
......@@ -591,11 +624,37 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[data] + list(args[1:]),
kwargs,
)
return [
Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape)
for split_tensor in func_out
t_func_out = [None] * len(func_out)
# Compute corresponding split of the transpose cache if available
if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
ndim = data.dim()
# Figure out the original split dim
if "dim" in kwargs:
dim_to_split = kwargs["dim"]
else:
dim_to_split = args[2] if len(args) > 2 else 0
# Dimension along which transpose needs to be split
t_dim = 0 if dim_to_split == ndim - 1 else dim_to_split + 1
t_func_out = transpose.__torch_dispatch__(
func,
types,
[transpose, args[1], t_dim],
kwargs,
)
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=split_tensor.shape,
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
]
return outs
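# Worked example of the transpose-split mapping above (shapes illustrative):
# data of shape [256, 128] split along dim 0 -> the cached transpose has shape
# [128, 256] and must be split along dim 1, i.e. t_dim = dim_to_split + 1 = 1;
# splitting along the last data dim (dim 1) instead maps to t_dim = 0.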
if func == aten.new_zeros.default:
# create fresh new tensor with zeros.
tensor = args[0]
data = tensor._data
func_out = data.__torch_dispatch__(
......@@ -604,28 +663,82 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
func_transposed_out = None
if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
size = args[1]
t_shape = [size[-1]] + list(size[:-1])
func_transposed_out = transpose.__torch_dispatch__(
func,
types,
[transpose, t_shape] + list(args[2:]),
kwargs,
)
# deep copy the scale inverse tensor and quantizer as well.
scale_inv = tensor._scale_inv.detach().clone()
quantizer = tensor._quantizer.copy()
out_tensor = Float8Tensor(
data=func_out,
shape=func_out.shape,
dtype=tensor.dtype,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=scale_inv,
data_transpose=func_transposed_out,
quantizer=quantizer,
)
return out_tensor
if func == torch.ops.aten.as_strided.default:
tensor = args[0]
data = tensor._data
# Apply as_strided to the primary uint8 data
func_out = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
func_transposed_out = None
if tensor._transpose is not None and not tensor._transpose_invalid:
transpose = tensor._transpose
size = args[1]
stride = args[2]
if "storage_offset" in kwargs:
storage_offset = kwargs["storage_offset"]
else:
storage_offset = args[3] if len(args) > 3 else 0
# Shape and stride needed for the transpose matrix
t_size = [size[-1]] + list(size[:-1])
t_stride = [stride[-1]] + list(stride[:-1])
func_transposed_out = transpose.__torch_dispatch__(
func,
types,
[transpose, t_size, t_stride, storage_offset] + list(args[4:]),
kwargs,
)
return Float8Tensor.make_like(
tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape
)
if func == torch.ops.aten.detach.default:
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:
return cls.clone(args[0])
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
# Just copy FP8 attrs if copying between Float8Tensors
if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor):
dst._data.copy_(src._data.detach())
dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size()))
if src._transpose is not None or dst._transpose is not None:
if dst._data is not None:
dst._data.copy_(src._data.detach(), *args[2:], **kwargs)
if dst._scale_inv is not None:
dst._scale_inv.copy_(
src._scale_inv.view(dst._scale_inv.size()), *args[2:], **kwargs
)
if dst._transpose is not None and not dst._transpose_invalid:
if not src._transpose_invalid:
dst._transpose.copy_(src._transpose, *args[2:], **kwargs)
else:
dst._create_transpose()
return dst
elif func in _ops_to_preserve_subclass_in_fsdp2:
......@@ -636,9 +749,105 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
)
else:
pass
return super().__torch_dispatch__(func, types, args, kwargs)
def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.
orig_size (torch.Size): Original size of the weight tensor (for us, same as self.shape).
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(for us, same as self.stride()).
module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard
that contains this FP8 tensor.
mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2.
Returns:
sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered (in this case, the uint8 data tensor).
metadata: Tuple[Any]: Metadata needed for reconstructing the
Float8Tensor after all-gather.
"""
# pylint: disable=unused-argument
# Importing here to avoid circular imports
from transformer_engine.pytorch.distributed import _get_module_fsdp_state
if isinstance(self._quantizer, Float8CurrentScalingQuantizer) and mesh is not None:
# When the sharded weight is updated after reduce-scattering the gradients in FSDP2,
# we need to do an amax reduction across the mesh to make sure all weight shards are
# updated with the same scale inverse. Setting the state below on the quantizer makes
# sure that the updated quantized weight tensors have the same scale inverse across all shards.
self._quantizer.amax_reduction_group = mesh.get_group()
self._quantizer.with_amax_reduction = True
quantizer = self._quantizer.copy() # quantizer to be used for allgathered weights
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
# If weights are resharded after the forward pass, it is enough to set the quantizer usages
# for the all-gathered weights based on whether this is the forward or backward pass.
# If they are not resharded after the forward pass, the same weights all-gathered in forward
# are used again in backward, so we do not change the quantizer usages, which may need
# both rowwise and columnwise usages.
if reshard_after_forward:
training_state = fsdp_state._fsdp_param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# On Hopper/L40, only one of data/transpose is needed depending on whether this is
# the forward or backward pass, so set the quantizer usages appropriately.
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
sharded_tensors = (self._data,)
metadata = (self._scale_inv, self._fp8_dtype, quantizer)
return sharded_tensors, metadata
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[Float8Tensor] = None,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank
are all-gathered and received here as a tuple.
metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the Float8Tensor.
param_dtype (torch.dtype): high precision dtype of the Float8Tensor.
out (Optional[torch.Tensor], optional): _description_. Defaults to None.
Returns:
Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: Allgathered Float8Tensor and tuple of internal tensors
used by the Float8Tensor that was being computed after allgather.
"""
(data,) = all_gather_outputs
(fp8_scale_inv, fp8_dtype, quantizer) = metadata
orig_shape = data.size()
# The quantizer has only columnwise usage set for the backward pass.
# On Blackwell+ architectures, the transpose is not needed at all,
# even if columnwise usage is set; this is handled
# internally in the update_usage method.
if out is not None:
out._data = data
else:
fp8_args = {
"shape": orig_shape,
"dtype": param_dtype,
"fp8_scale_inv": fp8_scale_inv,
"fp8_dtype": fp8_dtype,
"quantizer": quantizer,
"requires_grad": False,
"data": data,
}
out = Float8Tensor(**fp8_args)
out.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
return out, all_gather_outputs
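A condensed sketch of the FSDP2 extension-point contract implemented by the two methods above; the all_gather argument below is a stand-in for FSDP2's internal collective, and only the uint8 data is actually communicated:

def fsdp2_round_trip_sketch(shard, mesh, module, mp_policy, all_gather, param_dtype):
    # Rank-local, before the collective: expose only the uint8 data and keep
    # scale_inv / fp8 dtype / quantizer as metadata.
    (data_shard,), metadata = shard.fsdp_pre_all_gather(
        mesh, shard.size(), shard.stride(), module, mp_policy
    )
    gathered_data = all_gather(data_shard)  # placeholder for FSDP2's collective
    # Rank-local, after the collective: rebuild a full Float8Tensor around the bytes.
    full_weight, _ = shard.fsdp_post_all_gather((gathered_data,), metadata, param_dtype)
    return full_weight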
@classmethod
def _make_in_reduce_ex(
cls,
......@@ -756,6 +965,9 @@ class _ViewFunc(torch.autograd.Function):
out_transpose_shape = out_transpose.size()
if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]:
out_transpose = None
else:
view_shape_for_transpose = [shape[-1]] + list(shape[:-1])
out_transpose = out_transpose.view(*view_shape_for_transpose)
return Float8Tensor(
shape=out_shape,
dtype=tensor.dtype,
......@@ -800,6 +1012,9 @@ class _ReshapeFunc(torch.autograd.Function):
out_transpose_shape = out_transpose.size()
if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]:
out_transpose = None
else:
reshape_shape_for_transpose = [shape[-1]] + list(shape[:-1])
out_transpose = out_transpose.reshape(*reshape_shape_for_transpose)
return Float8Tensor(
shape=out_shape,
dtype=tensor.dtype,
......
......@@ -6,22 +6,20 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Any
import warnings
import torch
from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
from ..constants import MXFP8_BLOCK_SCALING_SIZE
from ..utils import devices_match, round_up_to_nearest_multiple
from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
from .quantized_tensor import (
QuantizedTensor,
Quantizer,
_IdentityFunc,
)
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
aten = torch.ops.aten
......@@ -92,6 +90,7 @@ class MXFP8Quantizer(Quantizer):
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> MXFP8Tensor:
# Canonicalize tensor attributes
......@@ -107,24 +106,29 @@ class MXFP8Quantizer(Quantizer):
)
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
data = None
scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Allocate FP8 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty_like(data)
columnwise_data = torch.empty_like(data, pin_memory=pin_memory)
columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
# Construct FP8 tensor
......@@ -301,7 +305,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
memory_format: torch.memory_format = torch.contiguous_format,
) -> MXFP8Tensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
......@@ -317,7 +320,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
if func == aten.view.default:
tensor = args[0]
......@@ -341,9 +343,339 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
fp8_dtype=tensor._fp8_dtype,
)
if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1]
if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor):
# Check whether src provides every usage (rowwise/columnwise) that dst holds.
# If not, fall back to the base class behavior.
rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
columnwise_matches = (
src._columnwise_data is not None or dst._columnwise_data is None
)
if rowwise_matches and columnwise_matches:
if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs)
dst._rowwise_scale_inv.copy_(
src._rowwise_scale_inv.detach(), *args[2:], **kwargs
)
if dst._columnwise_data is not None:
dst._columnwise_data.copy_(
src._columnwise_data.detach(), *args[2:], **kwargs
)
dst._columnwise_scale_inv.copy_(
src._columnwise_scale_inv.detach(), *args[2:], **kwargs
)
return dst
# FSDP2 related functions.
if func == aten.split.Tensor:
# This is called when the entire model is initialized on a CUDA device and
# then split. The shard needed by this process is kept and the other
# split shards are discarded.
if "dim" in kwargs:
dim_to_split = kwargs["dim"]
else:
dim_to_split = args[2] if len(args) > 2 else 0
tensor = args[0]
split_size = args[1]
dim0_size = tensor.size(0)
dimlast_size = math.prod(tensor.shape[1:])
if (
dim0_size % split_size != 0
or dim_to_split != 0
or split_size % MXFP8_BLOCK_SCALING_SIZE != 0
or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0
):
# Handle splitting by dequantizing and splitting the hp tensor
return super().__torch_dispatch__(func, types, args, kwargs)
out_data = []
for data in [tensor._rowwise_data, tensor._columnwise_data]:
func_out = (
data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
if data is not None
else None
)
out_data.append(func_out)
scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv]
split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE]
# Padding requirements: rowwise dim0 should be divisible by 128, columnwise dim0 should be divisible by 4
padding_multiples = [128, 4]
for scale_inv, scale_split_size, pad_multiple in zip(
scale_invs, split_sizes_for_scale, padding_multiples
):
scale_inv_out = (
scale_inv.__torch_dispatch__(
func,
types,
[scale_inv, scale_split_size] + list(args[2:]),
kwargs,
)
if scale_inv is not None
else None
)
# Pad scale_inv_out to be a multiple of pad_multiple
if scale_inv_out is not None:
current_shape = scale_inv_out.shape
pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
if pad_dim0 > 0:
scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0))
out_data.append(scale_inv_out)
return [
MXFP8Tensor(
shape=(
splitted_tensor_data[0].size()
if splitted_tensor_data[0] is not None
else splitted_tensor_data[1].size()
),
dtype=tensor.dtype,
rowwise_data=splitted_tensor_data[0],
rowwise_scale_inv=splitted_tensor_data[2],
columnwise_data=splitted_tensor_data[1],
columnwise_scale_inv=splitted_tensor_data[3],
quantizer=tensor._quantizer,
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
for splitted_tensor_data in zip(*out_data)
]
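The split stays in MXFP8 only when the divisibility checks above pass; the scale_inv tensors are then split (with a smaller dim-0 size for the columnwise scales) and re-padded. A small worked sketch of that arithmetic (illustrative only; assumes MXFP8_BLOCK_SCALING_SIZE == 32):
import math

def split_stays_quantized(shape, split_size, dim=0, block=32):
    # Mirror of the checks in the aten.split.Tensor branch above.
    return (
        dim == 0
        and shape[0] % split_size == 0
        and split_size % block == 0
        and math.prod(shape[1:]) % block == 0
    )

def scale_pad(dim0, multiple):
    # Dim-0 padding applied to each scale_inv shard after the split.
    return (multiple - dim0 % multiple) % multiple

print(split_stays_quantized((1024, 4096), 128))      # True  -> shards stay MXFP8
print(split_stays_quantized((1024, 4096), 48))       # False -> dequantize fallback
print(scale_pad(160, 128), scale_pad(160 // 32, 4))  # 96 3  (rowwise, columnwise pads)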
if func == torch.ops.aten.as_strided.default:
# Applied on the unsharded param in FSDP2. In our case, this should be a no-op.
# This is needed for the case where some MXFP8 shards need padding, i.e. dimension 0
# of the unsharded param is not a multiple of the world size. If that is the case,
# we go down the dequantization route and weights are all-gathered in high precision.
# If the weight doesn't need padding, this is just a no-op.
shape = args[1]
strides = args[2]
tensor = args[0]
if (
len(shape) != 2
or len(strides) != 2
or strides[1] != 1
or shape[0] != tensor.shape[0]
or shape[1] != tensor.shape[1]
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
if func == aten.slice.Tensor:
# Needed by FSDP2.
# Slicing is required when some MXFP8 weight shards need padding, i.e. dimension 0
# of the unsharded param is not a multiple of the world size. If that is the case,
# we go down the dequantization route and weights are all-gathered in high precision instead.
# If the sharded weight doesn't have padding, this is just a no-op.
dim = args[1]
start = args[2]
length = args[3]
tensor = args[0]
if (
dim != 0
or length != tensor.shape[0]
or start != 0
or length % MXFP8_BLOCK_SCALING_SIZE != 0
or start % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
return MXFP8Tensor.make_like(tensor)
if func == aten.new_zeros.default:
rowwise_data = None
columnwise_data = None
rowwise_scale_inv = None
columnwise_scale_inv = None
tensor = args[0]
shape = args[1]
first_dim = math.prod(shape[:-1])
last_dim = shape[-1]
if (
first_dim % MXFP8_BLOCK_SCALING_SIZE != 0
or last_dim % MXFP8_BLOCK_SCALING_SIZE != 0
):
return super().__torch_dispatch__(func, types, args, kwargs)
rowwise_scale_inv_shape = [first_dim, last_dim // MXFP8_BLOCK_SCALING_SIZE]
columnwise_scale_inv_shape = [
first_dim // MXFP8_BLOCK_SCALING_SIZE,
last_dim,
]
if tensor._rowwise_data is not None:
rowwise_data = tensor._rowwise_data.__torch_dispatch__(
func,
types,
[tensor._rowwise_data] + list(args[1:]),
kwargs,
)
rowwise_scale_inv = tensor._rowwise_scale_inv.__torch_dispatch__(
func,
types,
[tensor._rowwise_scale_inv, rowwise_scale_inv_shape] + list(args[2:]),
kwargs,
)
if tensor._columnwise_data is not None:
columnwise_data = tensor._columnwise_data.__torch_dispatch__(
func,
types,
[tensor._columnwise_data] + list(args[1:]),
kwargs,
)
columnwise_scale_inv = tensor._columnwise_scale_inv.__torch_dispatch__(
func,
types,
[tensor._columnwise_scale_inv, columnwise_scale_inv_shape] + list(args[2:]),
kwargs,
)
return MXFP8Tensor(
shape=args[1],
dtype=tensor.dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=tensor._quantizer.copy(),
requires_grad=False,
fp8_dtype=tensor._fp8_dtype,
)
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.
orig_size (torch.Size): Original size of the weight tensor (for us, the same as self.shape).
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(for us, the same as self.stride()).
module (FSDPModule): The FSDP module (wrapped with fully_shard)
that contains this MXFP8 tensor.
mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2.
Returns:
sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered.
metadata: Tuple[Any]: Metadata needed for reconstructing the
MXFP8Tensor after all-gather.
"""
# pylint: disable=unused-argument
from transformer_engine.pytorch.distributed import _get_module_fsdp_state
fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
quantizer = self._quantizer.copy()
# Remove padding from the scale inverses before the all-gather.
# Rowwise scale_inv dims are padded to multiples of [128, 4], columnwise to [4, 128].
rowwise_scale_inv = self._rowwise_scale_inv
columnwise_scale_inv = self._columnwise_scale_inv
shape = self.shape
if rowwise_scale_inv is not None:
# Remove padding from rowwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1])
if rowwise_scale_inv.size(0) != flattened_in_shape0:
rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0]
if columnwise_scale_inv is not None:
# Remove padding from columnwise scale_inv
flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE
if columnwise_scale_inv.size(0) != flattened_in_shape0:
columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0]
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
# If weights are resharded after the forward pass, it is enough to set the quantizer usages
# for the all-gathered weights based on whether this is the forward or backward pass.
# If they are not resharded after forward, the same weights all-gathered in forward
# are reused in backward, so if the columnwise data/scale_inv are needed,
# they must be sent for all-gather in the forward pass as well.
if reshard_after_forward:
training_state = fsdp_state._fsdp_param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
sharded_tensors = (
(self._columnwise_data, columnwise_scale_inv)
if is_backward_pass
else sharded_tensors
)
else:
if quantizer.columnwise_usage:
# If weights are not resharded after forward, then both
# rowwise and columnwise data/scale_inv need to be allgathered.
sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
metadata = (self._fp8_dtype, quantizer)
return sharded_tensors, metadata
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[MXFP8Tensor] = None,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): The sharded_tensors returned by
fsdp_pre_all_gather on each rank, all-gathered and received here as a tuple.
metadata (Any): The metadata returned by fsdp_pre_all_gather, used to reconstruct the MXFP8Tensor.
param_dtype (torch.dtype): High-precision dtype of the MXFP8Tensor.
out (Optional[MXFP8Tensor], optional): Pre-allocated tensor to update in place with
the all-gathered data. Defaults to None.
Returns:
Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: The all-gathered MXFP8Tensor and the tuple of
internal tensors it was constructed from.
"""
fp8_dtype, quantizer = metadata
rowwise_data, rowwise_scale_inv = (
all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None)
)
columnwise_data, columnwise_scale_inv = (
all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None)
)
# Pad the scale_inv tensors to multiples of [128, 4] for rowwise and [4, 128] for columnwise
if rowwise_scale_inv is not None:
# Pad rowwise_scale_inv to be a multiple of [128, 4]
current_shape = rowwise_scale_inv.shape
pad_dim0 = (128 - current_shape[0] % 128) % 128
if pad_dim0 > 0:
rowwise_scale_inv = torch.nn.functional.pad(rowwise_scale_inv, (0, 0, 0, pad_dim0))
if columnwise_scale_inv is not None:
# Pad columnwise_scale_inv to be a multiple of [4, 128]
current_shape = columnwise_scale_inv.shape
pad_dim0 = (4 - current_shape[0] % 4) % 4
if pad_dim0 > 0:
columnwise_scale_inv = torch.nn.functional.pad(
columnwise_scale_inv, (0, 0, 0, pad_dim0)
)
if out is not None:
out._rowwise_data = rowwise_data
out._rowwise_scale_inv = rowwise_scale_inv
out._columnwise_data = columnwise_data
out._columnwise_scale_inv = columnwise_scale_inv
out._quantizer = quantizer
else:
out = MXFP8Tensor(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=fp8_dtype,
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
quantizer=quantizer,
)
return out, all_gather_outputs
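fsdp_pre_all_gather and fsdp_post_all_gather implement FSDP2's tensor-subclass all-gather extension, so modules whose weights are MXFP8Tensors can be sharded with fully_shard and all-gathered directly as FP8 data plus scales. A minimal usage sketch, not taken from this diff (the import paths and the fp8_model_init recipe argument are assumptions that vary across torch/TE releases; requires an initialized NCCL process group and MXFP8-capable hardware):
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy  # path varies by torch release
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

# Assumed setup: torch.distributed has already been initialized with a CUDA backend.
with torch.device("cuda"), te.fp8_model_init(recipe=MXFP8BlockScaling()):
    block = te.Linear(4096, 4096, params_dtype=torch.bfloat16)

# Sharding the module hands its MXFP8 weight shards to FSDP2; the hooks above are
# then called around every all-gather instead of gathering high-precision params.
fully_shard(block, mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16))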
@classmethod
def _make_in_reduce_ex(
cls,
......@@ -481,10 +813,14 @@ class _ViewFunc(torch.autograd.Function):
shape[i] = d_inferred
break
if shape[-1] != ctx.shape[-1]:
raise RuntimeError(
"MXFP8Tensor does not support reshaping inner dimension "
warnings.warn(
"MXFP8Tensor does not support reshaping inner dimension. "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
"If you are using this for FSDP2 without compiled_autograd_enabled,"
"then ignore this warning. Since this view is not going to be used anywhere. ",
stacklevel=2,
)
return tensor.dequantize().view(*shape)
# Construct new tensor if shape is provided
new_rowwise_data = None
......
......@@ -6,7 +6,7 @@
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import functools
import torch
......@@ -22,14 +22,15 @@ from ..utils import (
)
from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..quantized_tensor import QuantizedTensor, Quantizer
from ._quantization_helpers import _IdentityFunc
aten = torch.ops.aten
def get_no_random_sign_vector() -> torch.Tensor:
"""Non-random sign vector for Hadamard transform."""
return torch.tensor([1], dtype=torch.float32)
return torch.tensor([1], dtype=torch.float32, device="cuda")
def get_sign_from_vector(vector: torch.Tensor) -> int:
......@@ -41,7 +42,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int:
mask = 0
for i, v in enumerate(vector):
mask |= (v == -1) << i
return mask
return mask.item()
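For context, get_sign_from_vector packs a ±1 vector into an integer bitmask, with bit i set when entry i is -1 (the .item() call above turns it into a plain Python int). A small CPU-only sketch:
import torch

def sign_mask(vector: torch.Tensor) -> int:
    # Same packing as get_sign_from_vector, written out for illustration.
    mask = 0
    for i, v in enumerate(vector):
        mask |= int(v == -1) << i
    return mask

print(sign_mask(torch.tensor([1.0, -1.0, 1.0, -1.0])))  # 10 == 0b1010 (bits 1 and 3 set)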
def get_wgrad_sign_vector() -> torch.Tensor:
......@@ -53,6 +54,7 @@ def get_wgrad_sign_vector() -> torch.Tensor:
return torch.tensor(
[1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
dtype=torch.float32,
device="cuda",
)
......@@ -81,6 +83,7 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
[1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
],
dtype=torch.float32,
device="cuda",
)
* hadamard_scale
)
......@@ -94,9 +97,9 @@ def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
signs = get_wgrad_sign_vector()
else:
signs = get_no_random_sign_vector()
sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32)
sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device="cuda")
rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
return rht_matrix.to(dtype=torch.bfloat16).cuda()
return rht_matrix.to(dtype=torch.bfloat16)
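get_rht_matrix builds the signed, scaled 16x16 Hadamard matrix used for the random Hadamard transform; with this change the intermediates are created on the GPU directly instead of being moved there at the end. A CPU-only sketch of the same construction (the Hadamard matrix is built recursively here for illustration, whereas the real code hard-codes the 16x16 matrix):
import torch

def hadamard(n: int) -> torch.Tensor:
    # Sylvester construction of an n x n Hadamard matrix (n must be a power of two).
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

dim = 16
signs = torch.tensor([1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1], dtype=torch.float32)
rht = (signs * torch.eye(dim)) @ (hadamard(dim) / dim**0.5)
# The result is orthogonal, so applying it preserves norms before quantization.
print(rht.shape, torch.allclose(rht @ rht.T, torch.eye(dim)))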
@functools.lru_cache(maxsize=None)
......@@ -262,6 +265,7 @@ class NVFP4Quantizer(Quantizer):
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
pin_memory: bool = False,
requires_grad: bool = False,
) -> NVFP4Tensor:
......@@ -285,11 +289,18 @@ class NVFP4Quantizer(Quantizer):
scale_inv = None
amax_rowwise = None
if self.rowwise_usage:
data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device)
data = torch.empty(
self.convert_shape_for_fp4(shape),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)
scale_inv = torch.empty(
scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
# Allocate per tensor scale inverse. FP32 format.
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)
# Allocate FP8 data transpose if needed
columnwise_data = None
......@@ -303,12 +314,15 @@ class NVFP4Quantizer(Quantizer):
self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape, dtype=torch.uint8, device=device
columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
amax_columnwise = torch.zeros(
1, dtype=torch.float32, device=device, pin_memory=pin_memory
)
amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device)
# Construct FP8 tensor
return NVFP4Tensor(
......@@ -495,6 +509,12 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
return self
raise ValueError("NVFP4Tensor does not support different memory formats!")
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
return {
"rowwise": self._rowwise_data is not None,
"columnwise": self._columnwise_data is not None,
}
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
......@@ -517,16 +537,20 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
)
if tensor._rowwise_data is not None:
rowwise_data = data_init_func(tensor._rowwise_data)
rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
rowwise_data = data_init_func(tensor._rowwise_data, *args[1:], **kwargs)
rowwise_scale_inv = scale_inv_init_func(
tensor._rowwise_scale_inv, *args[1:], **kwargs
)
amax_rowwise = torch.zeros_like(tensor._amax_rowwise, *args[1:], **kwargs)
else:
rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None
if tensor._columnwise_data is not None:
columnwise_data = data_init_func(tensor._columnwise_data)
columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
columnwise_data = data_init_func(tensor._columnwise_data, *args[1:], **kwargs)
columnwise_scale_inv = scale_inv_init_func(
tensor._columnwise_scale_inv, *args[1:], **kwargs
)
amax_columnwise = torch.zeros_like(tensor._amax_columnwise, *args[1:], **kwargs)
else:
columnwise_data, columnwise_scale_inv, amax_columnwise = (
None,
......
......@@ -14,12 +14,10 @@ from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from ..quantized_tensor import QuantizedTensorStorage
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from ...constants import TE_DType_To_Torch
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
......@@ -423,3 +421,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
return
return
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
return {
"rowwise": self._rowwise_data is not None,
"columnwise": self._columnwise_data is not None,
}
......@@ -12,12 +12,10 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorStorage
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor
......@@ -227,3 +225,12 @@ class Float8TensorStorage(QuantizedTensorStorage):
if not needs_data_transpose:
self._transpose = None
self._transpose_invalid = True
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
usages = {"rowwise": self._data is not None}
if is_non_tn_fp8_gemm_supported():
usages["columnwise"] = self._data is not None
else:
usages["columnwise"] = self._transpose is not None and not self._transpose_invalid
return usages
......@@ -13,12 +13,10 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorStorage
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor
......@@ -256,3 +254,10 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
else:
self._columnwise_data = None
self._columnwise_scale_inv = None
def get_usages(self) -> Dict[str, bool]:
"""Get the usage of the tensor"""
return {
"rowwise": self._rowwise_data is not None,
"columnwise": self._columnwise_data is not None,
}
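All of the new get_usages helpers return the same {"rowwise": bool, "columnwise": bool} dict, so callers can check whether a quantized tensor still carries the layouts they need. A hedged sketch of one possible consumer (the copy_can_stay_quantized name and the dst/src variables are hypothetical, not part of this diff):
def copy_can_stay_quantized(dst, src) -> bool:
    # dst and src are any two QuantizedTensorStorage objects exposing get_usages().
    dst_usages, src_usages = dst.get_usages(), src.get_usages()
    # src must provide every layout that dst needs (same rule as the copy_ override above).
    return all(src_usages[key] for key, needed in dst_usages.items() if needed)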