Commit 9df0c4a3 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents 0d874a4e f122b07d
......@@ -42,7 +42,7 @@ class BackwardAddRMSNorm(FusedOperation):
# Get basic operations
rmsnorm_op = self.basic_ops[1]
rmsnorm_op_ctx = basic_op_ctxs[0]
rmsnorm_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass
x, rstdevs = rmsnorm_op_ctx.saved_tensors
......@@ -53,7 +53,7 @@ class BackwardAddRMSNorm(FusedOperation):
# Check input tensors
dtype = rmsnorm_op_ctx.dtype
extra_grad = basic_op_grad_extra_outputs[1][0]
extra_grad = basic_op_grad_extra_outputs[0][0]
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,))
add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size())
......@@ -77,57 +77,51 @@ class BackwardAddRMSNorm(FusedOperation):
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
return grad_input, [(grad_weight,), ()], [(), ()]
def fuse_backward_add_rmsnorm(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward RMSNorm + add
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
return grad_input, [(), (grad_weight,)], [(), ()]
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:2], ops[2:]
while len(window) == 2:
if (
isinstance(window[0], MakeExtraOutput)
and isinstance(window[1], RMSNorm)
and not window[0]._in_place
):
# Construct fused op if window matches pattern
op = BackwardAddRMSNorm(add=window[0], rmsnorm=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, RMSNorm):
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardAddRMSNorm(
rmsnorm=window[0][0],
add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -45,7 +45,7 @@ class BackwardLinearAdd(FusedOperation):
# Get basic operations
linear_op = self.basic_ops[1]
linear_op_ctx = basic_op_ctxs[0]
linear_op_ctx = basic_op_ctxs[1]
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
......@@ -71,7 +71,7 @@ class BackwardLinearAdd(FusedOperation):
accumulate_into_main_grad = False
# Linear backward pass
grad_input = basic_op_grad_extra_outputs[1][0]
grad_input = basic_op_grad_extra_outputs[0][0]
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
......@@ -109,61 +109,60 @@ class BackwardLinearAdd(FusedOperation):
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(grad_weight,), ()], [(), ()]
def fuse_backward_linear_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dgrad GEMM + add
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
return grad_input, [(), (grad_weight,)], [(), ()]
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:2], ops[2:]
while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], MakeExtraOutput) and isinstance(window[1], BasicLinear)):
matches_pattern = False
elif not window[0]._in_place:
# Fused op accumulates grad input in-place
matches_pattern = False
elif window[1].tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication
# after the dgrad GEMM
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = BackwardLinearAdd(backward_add=window[0], linear=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication after the
# dgrad GEMM
continue
# Check if second op is "make extra output"
op, _ = ops[0]
if not isinstance(op, MakeExtraOutput):
continue
if not op._in_place:
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearAdd(
linear=window[0][0],
backward_add=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -45,7 +45,7 @@ class BackwardLinearScale(FusedOperation):
# Get basic operations
linear_op = self.basic_ops[0]
linear_op_ctx = basic_op_ctxs[1]
linear_op_ctx = basic_op_ctxs[0]
scale_op = self.basic_ops[1]
# Saved tensors from forward pass
......@@ -109,58 +109,57 @@ class BackwardLinearScale(FusedOperation):
zero=getattr(weight_param, "zero_out_wgrad", False),
)
return grad_input, [(), (grad_weight,)], [(), ()]
def fuse_backward_linear_scale(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dgrad GEMM + constant scale
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
return grad_input, [(grad_weight,), ()], [(), ()]
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:2], ops[2:]
while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale)):
matches_pattern = False
elif window[0].tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication
# after the dgrad GEMM
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = BackwardLinearScale(linear=window[0], scale=window[1])
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is constant scale
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, ConstantScale):
continue
# Check if second op is linear
op, _ = ops[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "column":
# Column tensor-parallelism requires communication after the dgrad GEMM
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = BackwardLinearScale(
scale=window[0][0],
linear=window[1][0],
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -134,62 +134,63 @@ class ForwardLinearBiasActivation(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_bias_activation(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + bias + activation
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:2], ops[2:]
while len(window) == 2:
# Check if window matches pattern
matches_pattern = True
if not (isinstance(window[0], BasicLinear) and isinstance(window[1], Bias)):
matches_pattern = False
elif window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
# the GEMM
matches_pattern = False
elif window[0].weight.dtype not in (torch.float16, torch.bfloat16):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = ForwardLinearBiasActivation(
linear=window[0],
bias=window[1],
activation=None,
)
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-1])
window = window[-1:]
# Adjust window to expected size
out.extend(window[:-2])
window = window[-2:]
while ops and len(window) < 2:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op1, _ = window[0]
if not isinstance(op1, BasicLinear):
continue
if op1.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
if op1.weight.dtype not in (torch.float16, torch.bfloat16):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
continue
# Check if second op is bias
op2, _ = ops[0]
if not isinstance(op2, Bias):
continue
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasActivation(
linear=window[0][0],
bias=window[1][0],
activation=None,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -131,72 +131,63 @@ class ForwardLinearBiasAdd(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))]
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while ops:
# Shift window
out.extend(window)
window = [ops[0]]
ops = ops[1:]
def fuse_forward_linear_bias_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + bias + add
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
# Check if first op is linear
if not isinstance(window[0], BasicLinear):
continue
if window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
# the GEMM
continue
linear = window[0]
Returns
-------
ops : list of tuples
Updated forward pass operations
# Check if next op is bias
bias = None
if ops and isinstance(ops[0], Bias):
window.append(ops[0])
ops = ops[1:]
bias = window[-1]
# Check if next op is in-place add extra input
if ops and isinstance(ops[0], AddExtraInput) and ops[0]._in_place:
window.append(ops[0])
ops = ops[1:]
add = window[-1]
else:
continue
"""
# Replace window with fused op
op = ForwardLinearBiasAdd(linear=linear, bias=bias, add=add)
window = [op]
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 2:
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is bias
bias = None
if isinstance(op, Bias):
bias = op
window.extend(ops[:1])
ops = ops[1:]
if len(ops) == 0:
continue
op, _ = ops[0]
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearBiasAdd(
linear=linear,
bias=bias,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -110,70 +110,66 @@ class ForwardLinearScaleAdd(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))]
def fuse_forward_linear_scale_add(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fuse forward GEMM + scale + add
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window = []
while len(ops) >= 3:
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:
# Check if window matches pattern
matches_pattern = True
if not (
isinstance(window[0], BasicLinear)
and isinstance(window[1], ConstantScale)
and isinstance(window[2], AddExtraInput)
):
matches_pattern = False
elif window[0].tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after
# the GEMM
matches_pattern = False
elif not window[2]._in_place:
# Fused op accumulates output in-place
matches_pattern = False
if matches_pattern:
# Construct fused op if window matches pattern
op = ForwardLinearScaleAdd(
linear=window[0],
scale=window[1],
add=window[2],
)
window = [op]
else:
# Shift window if window doesn't match pattern
out.extend(window[:-2])
window = window[-2:]
# Adjust window to expected size
out.extend(window[:-3])
window = window[-3:]
while ops and len(window) < 3:
window.append(ops[0])
ops = ops[1:]
# Return list of ops
out.extend(window)
# Check if first op is linear
window, ops = ops[:1], ops[1:]
op, _ = window[0]
if not isinstance(op, BasicLinear):
continue
if op.tensor_parallel_mode == "row":
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear = op
op, _ = ops[0]
# Check if next op is constant scale
if not isinstance(op, ConstantScale):
continue
scale = op
window.extend(ops[:1])
ops = ops[1:]
op, _ = ops[0]
# Check if next op is in-place add extra input
if not isinstance(op, AddExtraInput):
continue
if not op._in_place:
continue
add = op
window.extend(ops[:1])
ops = ops[1:]
# Replace window with fused op
op = ForwardLinearScaleAdd(
linear=linear,
scale=scale,
add=add,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out.extend(window)
out.extend(ops)
return out
return out
......@@ -503,7 +503,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get basic operations
idx = self._op_idxs["linear"]
linear_op = self.basic_ops[idx]
linear_op_ctx = basic_op_ctxs[-1]
linear_op_ctx = basic_op_ctxs[0]
bias_op = None
if self._op_idxs["bias"] is not None:
idx = self._op_idxs["bias"]
......@@ -578,99 +578,84 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_params[self._op_idxs["linear"]] = (grad_weight,)
if bias_op is not None:
grad_params[self._op_idxs["bias"]] = (grad_bias,)
grad_params.reverse()
grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
return grad_input, grad_params, grad_extra_inputs
@staticmethod
def fuse_backward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for backward pass.
def fuse_userbuffers_backward_linear(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Substitute linear operations with Userbuffers implementation
Parameters
----------
ops : list of FusibleOperation
Backward pass operations.
recipe : Recipe, optional
Quantization recipe.
Parameters
----------
ops : list of tuples
Backward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of FusibleOperation
Updated backward pass operations
Returns
-------
ops : list of tuples
Updated backward pass operations
"""
"""
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Sliding window in list of ops
window = []
def peek_next_op() -> Optional[FusibleOperation]:
"""Get next op in list of ops"""
nonlocal ops
if not ops:
return None
return ops[-1][0]
def pop_next_op() -> FusibleOperation:
"""Remove next op from list of ops and add to sliding window"""
nonlocal ops, window
window.insert(0, ops[-1])
ops = ops[:-1]
return window[0][0]
# Scan through ops in reverse order, fusing if possible
out_reversed = []
while ops:
out_reversed.extend(reversed(window))
window.clear()
# Check if next op is linear
next_op = pop_next_op()
if not isinstance(next_op, BasicLinear):
continue
linear = next_op
if linear._userbuffers_options is None:
continue
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
bias = pop_next_op()
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
reduce_scatter = pop_next_op()
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
# Scan through ops, fusing if possible
out = []
window = []
while ops:
# Shift window
out.extend(window)
window, ops = ops[:1], ops[1:]
# Check if first op is linear
if not isinstance(window[0], BasicLinear):
continue
if reduce_scatter.process_group_size == 1:
linear = window[0]
if linear._userbuffers_options is None:
continue
# Replace window with fused op
op = UserbuffersBackwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Return list of ops
out_reversed.extend(reversed(window))
out = out_reversed
out.reverse()
return out
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias):
bias, ops = ops[0], ops[1:]
window.append(bias)
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter):
reduce_scatter, ops = ops[0], ops[1:]
window.append(reduce_scatter)
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue
if reduce_scatter.process_group_size == 1:
continue
# Replace window with fused op
op = UserbuffersBackwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
window = [op]
# Return list of ops
out.extend(window)
return out
......@@ -369,93 +369,79 @@ class UserbuffersForwardLinear(FusedOperation):
return output, [() for _ in range(len(self.basic_ops))]
@staticmethod
def fuse_forward_ops(
ops: list[FusibleOperation],
**unused, # pylint: disable=unused-argument
) -> list[FusibleOperation]:
"""Apply operation fusion for forward pass.
def fuse_userbuffers_forward_linear(
ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Substitute linear operations with Userbuffers implementation
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated forward pass operations
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
"""
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Sliding window in list of ops
window = []
def peek_next_op() -> Optional[FusibleOperation]:
"""Get next op in list of ops"""
nonlocal ops
if not ops:
return None
return ops[0][0]
def pop_next_op() -> FusibleOperation:
"""Remove next op from list of ops and add to sliding window"""
nonlocal ops, window
window.append(ops[0])
ops = ops[1:]
return window[-1][0]
# Scan through ops, fusing if possible
out = []
while ops:
out.extend(window)
window.clear()
"""
# Check if next op is linear
next_op = pop_next_op()
if not isinstance(next_op, BasicLinear):
continue
linear = next_op
if linear._userbuffers_options is None:
continue
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias):
bias = pop_next_op()
# Scan through ops, fusing if possible
out = []
window = []
while ops:
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter):
reduce_scatter = pop_next_op()
# Shift window
out.extend(window)
window, ops = ops[:1], ops[1:]
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
# Check if first op is linear
if not isinstance(window[0], BasicLinear):
continue
if reduce_scatter.process_group_size == 1:
linear = window[0]
if linear._userbuffers_options is None:
continue
# Replace window with fused op
op = UserbuffersForwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
window = [(op, basic_op_idxs)]
# Check if next op is bias
bias = None
if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias):
bias, ops = ops[0], ops[1:]
window.append(bias)
# Check if next op is reduce-scatter
reduce_scatter = None
if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter):
reduce_scatter, ops = ops[0], ops[1:]
window.append(reduce_scatter)
# Check for invalid combinations
if reduce_scatter is None:
if linear.tensor_parallel_mode is None:
continue
if linear.tensor_parallel_size == 1:
continue
if linear.tensor_parallel_mode == "row" and bias is not None:
continue
else:
if linear.tensor_parallel_mode is not None:
continue
if reduce_scatter.process_group_size == 1:
continue
# Replace window with fused op
op = UserbuffersForwardLinear(
linear=linear,
bias=bias,
reduce_scatter=reduce_scatter,
)
window = [op]
# Return list of ops
out.extend(window)
return out
# Return list of ops
out.extend(window)
return out
......@@ -5,33 +5,20 @@
"""Manager class for a pipeline of fusible operations."""
from __future__ import annotations
from collections.abc import Callable, Iterable
from typing import Any, Optional
from collections.abc import Callable, Iterable, Sequence
import itertools
from typing import Any, Optional, TypeAlias
import torch
from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from transformer_engine.pytorch.ops.op import (
from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling
from ..quantized_tensor import prepare_for_saving, restore_from_saved
from .op import (
BasicOperation,
FusibleOperation,
FusedOperation,
OperationContext,
)
from transformer_engine.pytorch.ops.fused import (
fuse_backward_activation_bias,
fuse_backward_add_rmsnorm,
fuse_backward_linear_add,
fuse_backward_linear_scale,
fuse_forward_linear_bias_activation,
fuse_forward_linear_bias_add,
fuse_forward_linear_scale_add,
fuse_userbuffers_backward_linear,
fuse_userbuffers_forward_linear,
)
from transformer_engine.pytorch.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]:
......@@ -57,6 +44,12 @@ def _is_graph_capturing() -> bool:
return _is_graph_capturing_function()
# Type alias for a function that may perform operation fusion
OperationFusionFunction: TypeAlias = (
"Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]"
)
class _OperationFuserAutogradFunction(torch.autograd.Function):
"""Autograd function for a pipeline of operations
......@@ -241,7 +234,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
dx = grad_output
grad_params = [None for _ in range(len(basic_ops))]
grad_extra_inputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in backward_ops:
for op, basic_op_idxs in reversed(backward_ops):
# Stop if no more gradients are required
if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
......@@ -315,6 +308,10 @@ class OperationFuser:
"""
# Functions to perform operation fusion
forward_fusion_functions: list[OperationFusionFunction] = []
backward_fusion_functions: list[OperationFusionFunction] = []
def __init__(
self,
ops: list[FusibleOperation],
......@@ -334,7 +331,7 @@ class OperationFuser:
self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)
# Ops for forward and backward pass, will be populated in fuse_ops
# Ops for forward and backward pass, will be populated in maybe_fuse_ops
self._forward_ops: list[tuple[FusibleOperation, list[int]]]
self._backward_ops: list[tuple[FusibleOperation, list[int]]]
......@@ -349,31 +346,48 @@ class OperationFuser:
self._flat_basic_op_params = sum(self._basic_op_params, [])
@classmethod
def _fuse_forward_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe], # pylint: disable=unused-argument
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in forward pass"""
ops = fuse_userbuffers_forward_linear(ops)
ops = fuse_forward_linear_bias_add(ops)
ops = fuse_forward_linear_bias_activation(ops)
ops = fuse_forward_linear_scale_add(ops)
return ops
@classmethod
def _fuse_backward_ops(
def _fuse_ops(
cls,
ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: Sequence[BasicOperation],
fusion_funcs: Iterable[OperationFusionFunction],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Attempt to fuse operations in backward pass"""
ops = fuse_userbuffers_backward_linear(ops)
ops = fuse_backward_linear_add(ops)
ops = fuse_backward_linear_scale(ops)
ops = fuse_backward_activation_bias(ops, recipe)
ops = fuse_backward_add_rmsnorm(ops)
return ops
"""Apply operation fusions"""
# Apply op fusions
fused_ops = list(basic_ops)
for func in fusion_funcs:
fused_ops = func(fused_ops, recipe=recipe)
def raise_mismatch_error() -> None:
"""Throw error indicating invalid op fusion"""
raise RuntimeError(
"Found mismatch after fusing operations "
f"(basic_ops={[o.__class__.__name__ for o in basic_ops]}, "
f"fused_ops={[o.__class__.__name__ for o in fused_ops]})"
)
# Determine basic op indices corresponding to each op
out = []
idx = 0
for op in fused_ops:
if isinstance(op, FusedOperation):
idxs = []
for basic_op in op.basic_ops:
if basic_op is not basic_ops[idx]:
raise_mismatch_error()
idxs.append(idx)
idx += 1
out.append((op, idxs))
else:
if op is not basic_ops[idx]:
raise_mismatch_error()
out.append((op, [idx]))
idx += 1
if idx != len(basic_ops):
raise_mismatch_error()
return out
def maybe_fuse_ops(
self,
......@@ -424,12 +438,16 @@ class OperationFuser:
op.pre_first_fuser_forward()
# Prepare basic op lists for fusions
forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))
# Fuse ops
self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)
self._forward_ops = OperationFuser._fuse_ops(
self._basic_ops,
OperationFuser.forward_fusion_functions,
recipe=recipe,
)
self._backward_ops = OperationFuser._fuse_ops(
self._basic_ops,
OperationFuser.backward_fusion_functions,
recipe=recipe,
)
# Save current fusion params
self.recipe_type, self.first_op_requiring_backward = fusion_params
......@@ -491,3 +509,55 @@ class OperationFuser:
*extra_inputs,
)
return forward_func(*args)
def register_forward_fusion(
    op_fusion_func: OperationFusionFunction,
    prepend: bool = False,
) -> None:
    """Register function to perform operation fusion for forward pass.

    The fusion function should have the following signature:

    func(ops, *, recipe) -> updated ops

    Parameters
    ----------
    op_fusion_func: function
        Function that takes a list of operations and may substitute
        them with fused operations.
    prepend: bool, default = ``False``
        Whether the operation fuser should apply this fusion function
        first. The default is to apply it last.
    """
    registry = OperationFuser.forward_fusion_functions
    # Inserting at len(registry) is equivalent to appending, so a single
    # insert covers both the prepend and the default (append) cases.
    position = 0 if prepend else len(registry)
    registry.insert(position, op_fusion_func)
def register_backward_fusion(
    op_fusion_func: OperationFusionFunction,
    prepend: bool = False,
) -> None:
    """Register an operation-fusion function for the backward pass.

    The fusion function must have the signature:
    func(ops, *, recipe) -> updated ops

    Parameters
    ----------
    op_fusion_func: function
        Function that takes a list of operations and may substitute
        them with fused operations.
    prepend: bool, default = ``False``
        Whether the operation fuser should apply this fusion function
        first. The default is to apply it last.
    """
    registry = OperationFuser.backward_fusion_functions
    # Inserting at len(registry) is equivalent to appending.
    position = 0 if prepend else len(registry)
    registry.insert(position, op_fusion_func)
......@@ -123,7 +123,7 @@ class FusedSGD(Optimizer):
self.set_grad_none = set_grad_none
if self.set_grad_none is not None:
warnings.warn(
"set_grad_none kwarg in FusedAdam constructor is deprecated. "
"set_grad_none kwarg in FusedSGD constructor is deprecated. "
"Use set_to_none kwarg in zero_grad instead.",
DeprecationWarning,
)
......@@ -147,7 +147,7 @@ class FusedSGD(Optimizer):
if set_to_none is not None and set_to_none != self.set_grad_none:
raise ValueError(
f"Called zero_grad with set_to_none={set_to_none}, "
f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}"
f"but FusedSGD was initialized with set_grad_none={self.set_grad_none}"
)
set_to_none = self.set_grad_none
if set_to_none is None:
......
......@@ -69,7 +69,9 @@ class QuantizedTensorStorage:
f"{self.__class__.__name__} class does not implement get_usages function"
)
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
"""Prepare the tensor base for saving for backward"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement prepare_for_saving function"
......@@ -115,11 +117,18 @@ class QuantizedTensorStorage:
warnings.warn("Quantizer is being updated, this may affect model behavior")
self._quantizer = quantizer
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data from another QuantizedTensorStorage."""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement copy_from_storage function"
)
def prepare_for_saving(
*tensors: Union[torch.Tensor, QuantizedTensorStorage],
) -> Tuple[
list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]]
list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
list[Optional[QuantizedTensorStorage]],
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
......@@ -144,7 +153,10 @@ def restore_from_saved(
return_saved_tensors: bool = False,
) -> (
list[Optional[torch.Tensor | QuantizedTensorStorage]]
| tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]]
| tuple[
list[Optional[torch.Tensor | QuantizedTensorStorage]],
list[Optional[torch.Tensor]],
]
):
"""Recombine the tensor data and metadata during backward pass."""
tensor_objects = []
......
......@@ -11,7 +11,11 @@ from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Recipe,
)
from ..utils import canonicalize_process_group, devices_match
from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func
from ..quantized_tensor import QuantizedTensor, Quantizer
......@@ -155,6 +159,10 @@ class Float8Quantizer(Quantizer):
amin, amax = tensor.aminmax()
self.amax.copy_(torch.max(-amin, amax))
def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of the columnwise data for Float8 1D blockwise quantization."""
return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1])
def create_tensor_from_data(
self,
data: torch.Tensor,
......@@ -409,6 +417,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
quantizer=self,
)
def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of the columnwise data for Float8 1D blockwise quantization."""
return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1])
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Function using primitives with ONNX defined translations."""
if tensor.dtype != torch.float32:
......@@ -770,7 +782,10 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
kwargs,
)
return Float8Tensor.make_like(
tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape
tensor,
data=func_out,
data_transpose=func_transposed_out,
shape=func_out.shape,
)
if func == torch.ops.aten.detach.default:
......
......@@ -164,6 +164,49 @@ class MXFP8Quantizer(Quantizer):
# TODO(ksivamani): No calibration needed for mxfp8?
pass
def get_scale_shape(
self,
shape: Iterable[int],
columnwise: bool,
) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For MXFP8 1D blockwise quantization, blocksize is 32
Swizzle kernel will be performed before GEMM to suit the need of CuBLAS.
CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
"""
if columnwise:
# Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]]
# with padding to multiples of [4, 128]
return (
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128),
)
# Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE]
# with padding to multiples of [128, 4]
return (
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
)
def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]:
"""Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization."""
return rowwise_data_shape
def create_tensor_from_data(
self,
data: torch.Tensor,
......@@ -704,7 +747,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=fp8_dtype,
dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape),
quantizer=self._quantizer,
with_gemm_swizzled_scales=False,
)
......
......@@ -341,7 +341,10 @@ class NVFP4Quantizer(Quantizer):
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
columnwise_scale_shape,
dtype=torch.uint8,
device=device,
pin_memory=pin_memory,
)
amax_columnwise = torch.zeros(
1, dtype=torch.float32, device=device, pin_memory=pin_memory
......
......@@ -7,3 +7,4 @@ from .float8_tensor_storage import Float8TensorStorage # noqa: F401
from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401
from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401
from .grouped_tensor import GroupedTensor # noqa: F401
......@@ -74,6 +74,24 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
if t is not None:
t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another Float8BlockwiseQTensorStorage."""
if not isinstance(src, Float8BlockwiseQTensorStorage):
raise TypeError("copy_from_storage expects Float8BlockwiseQTensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
if self._is_2D_scaled != src._is_2D_scaled:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......
......@@ -104,6 +104,24 @@ class Float8TensorStorage(QuantizedTensorStorage):
t.data = _empty_tensor()
self._transpose_invalid = True
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another Float8TensorStorage."""
if not isinstance(src, Float8TensorStorage):
raise TypeError("copy_from_storage expects Float8TensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
def _copy_optional(
dst: Optional[torch.Tensor],
src_tensor: Optional[torch.Tensor],
):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._data, src._data)
_copy_optional(self._transpose, src._transpose)
_copy_optional(self._scale_inv, src._scale_inv)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Grouped tensor class for handling collections of tensors with different shapes"""
from __future__ import annotations
from typing import Optional, Tuple, List, Union
import math
import torch
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
from ..mxfp8_tensor import MXFP8Tensor
from ..nvfp4_tensor import NVFP4Tensor
from ..float8_tensor import Float8Tensor
from ..float8_blockwise_tensor import Float8BlockwiseQTensor
from .float8_tensor_storage import Float8TensorStorage
from .mxfp8_tensor_storage import MXFP8TensorStorage
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .nvfp4_tensor_storage import NVFP4TensorStorage
class GroupedTensor:
"""
EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE.
Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode.
Shape Representation:
- logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors
are flattened to 2D and stacked together (REQUIRED)
+ When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N)
+ When varying_first_dim(): [~sum_of_first_dims, N] where N is common
+ When varying_last_dim(): [M, ~sum_of_last_dims] where M is common
+ When varying_both_dims(): [1, total_elements] (fully flattened)
- first_dims and last_dims are OPTIONAL (None if dimension is uniform)
+ None first_dims: all tensors have the same first dimension
+ None last_dims: all tensors have the same last dimension
+ Both None: all tensors have identical shapes
+ Both set: each tensor has unique shape (first_dims[i], last_dims[i])
Data Layout:
- ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.)
- logical_shape provides the conceptual 2D interpretation
- All data is stored on device in contiguous layout
Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode.
"""
    def __init__(
        self,
        num_tensors: int,
        shape: List[Tuple[int, int]],
        quantizer: Optional[Quantizer] = None,
        dtype: Optional[torch.dtype] = None,
        data: Optional[torch.Tensor] = None,
        columnwise_data: Optional[torch.Tensor] = None,
        scale_inv: Optional[torch.Tensor] = None,
        columnwise_scale_inv: Optional[torch.Tensor] = None,
        amax: Optional[torch.Tensor] = None,
        columnwise_amax: Optional[torch.Tensor] = None,
        scale: Optional[torch.Tensor] = None,
        first_dims: Optional[torch.Tensor] = None,
        last_dims: Optional[torch.Tensor] = None,
        tensor_offsets: Optional[torch.Tensor] = None,
        offsets: Optional[List[int]] = None,
        scale_inv_offsets: Optional[List[int]] = None,
        columnwise_scale_inv_offsets: Optional[List[int]] = None,
        logical_shape: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Initialize a GroupedTensor.

        Only stores the provided buffers and metadata; no allocation is
        performed here (see ``make_grouped_tensor`` for allocation).

        Args:
            num_tensors: Number of tensors in the group
            shape: 2D shape of each tensor (len num_tensors)
            quantizer: Quantizer for the grouped tensor
            dtype: High-precision data type; defaults to float32 when None
            data: Row-wise data buffer (1D flattened)
            columnwise_data: Column-wise data buffer (1D flattened)
            scale_inv: Row-wise scale inverse buffer
            columnwise_scale_inv: Column-wise scale inverse buffer
            amax: Row-wise amax buffer
            columnwise_amax: Column-wise amax buffer
            scale: Scale buffer (for FP8-DS only)
            first_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
            last_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
            tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform)
            offsets: Vector of integer offsets for each tensor.
            scale_inv_offsets: Per-tensor element offsets into ``scale_inv``
            columnwise_scale_inv_offsets: Per-tensor element offsets into
                ``columnwise_scale_inv``
            logical_shape: 2D tuple representing conceptual shape
        """
        self.num_tensors = num_tensors
        self.quantizer = quantizer
        self.shape = shape
        self.dtype = (
            dtype if dtype is not None else torch.float32
        )  # Default to float32 if not provided

        # Data buffers (all 1D flattened; see class docstring for layout)
        self.data = data
        self.columnwise_data = columnwise_data
        self.scale_inv = scale_inv
        self.columnwise_scale_inv = columnwise_scale_inv
        self.amax = amax
        self.columnwise_amax = columnwise_amax
        self.scale = scale

        # For convenient indexing for python GroupedTensor API.
        self.scale_inv_offsets = scale_inv_offsets
        self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets

        # Shape information (OPTIONAL - None if dimension is uniform across all tensors)
        # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim)
        # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim)
        self.first_dims = (
            first_dims  # Device pointer to int64_t array of length num_tensors (or None)
        )
        self.last_dims = (
            last_dims  # Device pointer to int64_t array of length num_tensors (or None)
        )

        # Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape())
        # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1)
        # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size
        # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions)
        self.tensor_offsets = (
            tensor_offsets  # Device pointer to int64_t array of length num_tensors (or None)
        )
        self.offsets = offsets  # Vector of integer offsets for each tensor.

        # Logical shape: conceptual 2D shape of the grouped data (REQUIRED)
        # Represents how the 1D flattened data should be interpreted as 2D
        # Always 2D with positive dimensions
        self.logical_shape = logical_shape if logical_shape is not None else (0, 0)

        # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor.
        # Used as a convenience.
        self.quantized_tensors = None
def has_data(self) -> bool:
"""
Check if the tensor has row-wise data.
Returns:
True if data buffer is initialized, False otherwise
"""
return self.data is not None
def has_columnwise_data(self) -> bool:
"""
Check if the tensor has column-wise data.
Returns:
True if columnwise_data buffer is initialized, False otherwise
"""
return self.columnwise_data is not None
def all_same_first_dim(self) -> bool:
"""
Check if all tensors in the group have the same first dimension.
Returns:
True if first dimension is uniform across all tensors
"""
return self.first_dims is None
def all_same_last_dim(self) -> bool:
"""
Check if all tensors in the group have the same last dimension.
Returns:
True if last dimension is uniform across all tensors
"""
return self.last_dims is None
def all_same_shape(self) -> bool:
"""
Check if all tensors in the group have identical shapes.
Returns:
True if all tensors have the same shape
"""
return self.first_dims is None and self.last_dims is None
def varying_both_dims(self) -> bool:
"""
Check if both dimensions vary across tensors.
Returns:
True if both first and last dimensions vary
"""
return self.first_dims is not None and self.last_dims is not None
def get_common_first_dim(self) -> int:
"""
Get the common first dimension when all tensors share it.
Returns:
The common first dimension
Raises:
RuntimeError: If first dimension varies across tensors or logical_shape is not 2D
"""
if not self.all_same_first_dim():
raise RuntimeError("First dim varies across tensors")
if len(self.logical_shape) != 2:
raise RuntimeError("Logical shape must be 2D")
if self.all_same_shape():
# When both dims are uniform: logical_shape = [num_tensors * M, N]
return self.logical_shape[0] // self.num_tensors
# When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims]
return self.logical_shape[0]
def get_common_last_dim(self) -> int:
"""
Get the common last dimension when all tensors share it.
Returns:
The common last dimension
Raises:
RuntimeError: If last dimension varies across tensors or logical_shape is not 2D
"""
if not self.all_same_last_dim():
raise RuntimeError("Last dim varies across tensors")
if len(self.logical_shape) != 2:
raise RuntimeError("Logical shape must be 2D")
# For both uniform and varying first dim cases: logical_shape[1] is the common last dim
return self.logical_shape[1]
def get_dtype(self) -> torch.dtype:
"""
Get the high precision data type of the tensor.
Returns:
The high precision dtype of the data buffer
"""
return self.dtype
def clear(self) -> None:
"""
Reset tensor data and clear all buffers.
"""
self.data = None
self.columnwise_data = None
self.scale_inv = None
self.columnwise_scale_inv = None
self.amax = None
self.columnwise_amax = None
self.scale = None
self.first_dims = None
self.last_dims = None
self.tensor_offsets = None
self.logical_shape = (0, 0)
self.num_tensors = 0
self.quantizer = None
self.quantized_tensors = None
self.offsets = None
self.scale_inv_offsets = None
self.columnwise_scale_inv_offsets = None
def __repr__(self) -> str:
"""String representation of the GroupedTensor."""
return (
f"GroupedTensor(num_tensors={self.num_tensors}, "
f"shape={self.shape}, "
f"logical_shape={self.logical_shape}, "
f"dtype={self.get_dtype()})"
)
def __str__(self) -> str:
"""User-friendly string representation."""
shape_info = []
if self.all_same_shape():
shape_info.append("uniform shape")
else:
if not self.all_same_first_dim():
shape_info.append("varying first dim")
if not self.all_same_last_dim():
shape_info.append("varying last dim")
return (
f"GroupedTensor with {self.num_tensors} tensors "
f"({', '.join(shape_info) if shape_info else 'uniform'}), "
f"logical_shape={self.logical_shape}, "
f"dtype={self.get_dtype()}"
)
@staticmethod
def make_grouped_tensor_with_shapes(
num_tensors: int,
shape: List[Tuple[int, int]],
quantizer: Optional[Quantizer] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
"""
Create a GroupedTensor for storing multiple weight tensors of the same shape.
Args:
num_tensors: Number of tensors
shape: 2D shape of each tensor (len num_tensors)
quantizer: Quantizer for each tensor
device: Device to allocate tensors on, defaults to current cuda device
dtype: Data type of the tensor (for high precision case)
Returns:
A GroupedTensor.
"""
# First dim
first_dim_list = [s[0] for s in shape]
uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list)
logical_first_dim = sum(first_dim_list)
if uniform_first_dim:
first_dims = None
else:
first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device)
# Last dim
last_dim_list = [s[1] for s in shape]
logical_last_dim = last_dim_list[0]
assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform"
return GroupedTensor.make_grouped_tensor(
num_tensors=num_tensors,
first_dims=first_dims,
last_dims=None,
logical_first_dim=logical_first_dim,
logical_last_dim=logical_last_dim,
quantizer=quantizer,
device=device,
dtype=dtype,
)
@staticmethod
def make_grouped_tensor(
num_tensors: int,
first_dims: Optional[torch.Tensor],
last_dims: Optional[torch.Tensor],
logical_first_dim: int,
logical_last_dim: int,
quantizer: Optional[Quantizer] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
"""
Create a GroupedTensor for storing multiple weight tensors of the same shape.
Args:
num_tensors: Number of tensors
first_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
last_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
logical_first_dim: Logical first dimension
logical_last_dim: Logical last dimension
quantizer: Quantizer for each tensor
Used to figure out the recipe and what to allocate.
device: Device to allocate tensors on, defaults to current cuda device
dtype: Data type of the tensor (for high precision case)
Returns:
A GroupedTensor.
"""
# Set device
if device is None:
device = torch.cuda.current_device()
# Shape patterns and validation.
all_same_first = first_dims is None
all_same_last = last_dims is None
assert all_same_last, "Last dim must be uniform for GroupedTensor"
assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor"
assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor"
# assert (
# logical_first_dim % 128 == 0
# ), "Logical first dim must be divisible by 128"
# assert logical_last_dim % 128 == 0, "Logical last dim must be divisible by 128"
# Calculate tensor offsets (cumulative element offsets)
tensor_offsets = None
offsets = None
shape = []
if not all_same_first:
# Need explicit offsets for non-uniform shapes
# Offsets are based on number of elements and not pointers.
# Kernels need to calculate precise pointers based on size of elements.
# TODO(ksivaman): Single kernel + remove the host offset calculation.
tensor_offsets = torch.cat(
[
torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
torch.cumsum(first_dims * logical_last_dim, dim=0),
]
)
offsets = tensor_offsets.tolist()
first_dims_list = first_dims.tolist()
for i in range(num_tensors):
shape.append((first_dims_list[i], logical_last_dim))
else:
offsets = [
i * logical_first_dim * logical_last_dim // num_tensors
for i in range(num_tensors + 1)
]
for i in range(num_tensors):
shape.append((logical_first_dim // num_tensors, logical_last_dim))
# Calculate logical shape based
logical_shape = (logical_first_dim, logical_last_dim)
no_quantization = quantizer is None
rowwise_usage = quantizer.rowwise_usage if not no_quantization else True
columnwise_usage = quantizer.columnwise_usage if not no_quantization else False
# Calculate total elements across all tensors
total_elements = logical_first_dim * logical_last_dim
data = None
columnwise_data = None
scale_inv = None
columnwise_scale_inv = None
amax = None
columnwise_amax = None
scale = None
scale_inv_offsets = None
columnwise_scale_inv_offsets = None
if no_quantization:
assert dtype is not None, "dtype must be provided for unquantized GroupedTensor"
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=dtype, device=device)
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=dtype, device=device)
elif quantizer._get_compatible_recipe().mxfp8():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse buffer for MXFP8 - complex shape based on block scaling
# For grouped tensors, we need to calculate scale_inv size for all tensors
total_scale_elements = 0
scale_inv_offsets = [0]
for i, s in enumerate(shape):
scale_inv_shape = quantizer.get_scale_shape(s, False)
scale_elements = math.prod(scale_inv_shape)
total_scale_elements += scale_elements
if i < num_tensors - 1:
scale_inv_offsets.append(total_scale_elements)
scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device)
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Columnwise scale inverse buffer
total_columnwise_scale_elements = 0
columnwise_scale_inv_offsets = [0]
for i, s in enumerate(shape):
scale_inv_shape = quantizer.get_scale_shape(s, False)
columnwise_scale_elements = math.prod(scale_inv_shape)
total_columnwise_scale_elements += columnwise_scale_elements
if i < num_tensors - 1:
columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
columnwise_scale_inv = torch.empty(
total_columnwise_scale_elements, dtype=torch.uint8, device=device
)
elif quantizer._get_compatible_recipe().delayed():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse - one per tensor
scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
# One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
scale_inv_offsets = list(range(num_tensors))
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Columnwise scale inverse - one per tensor
columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
# One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
columnwise_scale_inv_offsets = list(range(num_tensors))
# Amax buffer for delayed scaling - one per tensor
amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
elif quantizer._get_compatible_recipe().nvfp4():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte)
data = torch.empty((total_elements) // 2, dtype=torch.uint8, device=device)
# Scale inverse buffer for NVFP4 - complex shape based on block scaling
# For simplicity, calculate total scale elements needed
total_scale_elements = 0
scale_inv_offsets = [0]
for i, s in enumerate(shape):
scale_inv_shape = quantizer.get_scale_shape(s, False)
total_scale_elements += math.prod(scale_inv_shape)
if i < num_tensors - 1:
scale_inv_offsets.append(total_scale_elements)
scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device)
# Amax buffer - one per tensor
amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8, FP4 packed)
columnwise_data = torch.empty(
(total_elements) // 2, dtype=torch.uint8, device=device
)
# Columnwise scale inverse buffer
total_columnwise_scale_elements = 0
columnwise_scale_inv_offsets = [0]
for i, s in enumerate(shape):
columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True)
total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape)
if i < num_tensors - 1:
columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
columnwise_scale_inv = torch.empty(
total_columnwise_scale_elements, dtype=torch.uint8, device=device
)
# Columnwise amax buffer - one per tensor
columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
elif quantizer._get_compatible_recipe().float8_block_scaling():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse - size depends on block configuration
# For simplicity, calculate total scale elements needed
total_scale_elements = 0
scale_inv_offsets = [0]
for i, s in enumerate(shape):
scale_inv_shape = quantizer.get_scale_shape(s, False)
total_scale_elements += math.prod(scale_inv_shape)
if i < num_tensors - 1:
scale_inv_offsets.append(total_scale_elements)
scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device)
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Columnwise scale inverse
total_columnwise_scale_elements = 0
columnwise_scale_inv_offsets = [0]
for i, s in enumerate(shape):
columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True)
total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape)
if i < num_tensors - 1:
columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
columnwise_scale_inv = torch.empty(
total_columnwise_scale_elements, dtype=torch.float32, device=device
)
elif quantizer._get_compatible_recipe().float8_current_scaling():
# Current scaling - per-tensor scaling computed on the fly
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse - one per tensor
scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
# One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
scale_inv_offsets = list(range(num_tensors))
if columnwise_usage:
# Allocate columnwise data buffer (1D flattened, uint8)
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Columnwise scale inverse - one per tensor
columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
# One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
columnwise_scale_inv_offsets = list(range(num_tensors))
# Scale and amax buffers for current scaling - one per tensor
scale = torch.empty(num_tensors, dtype=torch.float32, device=device)
amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
else:
raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}")
grouped_tensor = GroupedTensor(
num_tensors=num_tensors,
shape=shape,
dtype=dtype,
quantizer=quantizer,
data=data,
columnwise_data=columnwise_data,
scale_inv=scale_inv,
columnwise_scale_inv=columnwise_scale_inv,
amax=amax,
columnwise_amax=columnwise_amax,
scale=scale,
first_dims=first_dims,
last_dims=last_dims,
tensor_offsets=tensor_offsets,
offsets=offsets,
scale_inv_offsets=scale_inv_offsets,
columnwise_scale_inv_offsets=columnwise_scale_inv_offsets,
logical_shape=logical_shape,
)
grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
return grouped_tensor
    def split_into_quantized_tensors(
        self,
    ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]:
        """
        Split the GroupedTensor into a list of `num_tensors`
        quantized tensors based on the quantizer. No additional memory allocation is performed,
        so the tensors returned are the same as the ones used to create the GroupedTensor.
        If quantizer is None, returns normal torch tensors.
        If quantizer.internal is True, returns QuantizedTensorStorage.
        Otherwise, returns QuantizedTensor.

        Every returned tensor is a `view` into the shared flat buffers
        (`data`, `columnwise_data`, `scale_inv`, ...), so writes through the
        returned tensors mutate this GroupedTensor and vice versa.

        TODO(ksivaman): Block cases where any dims are varying. This is needed only
        to expose the weights as separate parameters.
        """
        result = []
        no_quantization = self.quantizer is None
        # Case 1: No quantization - return regular torch tensors viewing into
        # the rowwise buffer, falling back to the columnwise buffer.
        if no_quantization:
            for i in range(self.num_tensors):
                # Get tensor shape
                tensor_shape = self.shape[i]
                # Get tensor data slice
                if self.offsets is not None:
                    # Varying shapes: per-tensor start offsets were precomputed.
                    start_offset = self.offsets[i]
                    # NOTE(review): assumes each per-tensor shape is 2D — confirm
                    # against make_grouped_tensor_with_shapes.
                    numel = tensor_shape[0] * tensor_shape[1]
                    end_offset = start_offset + numel
                    if self.has_data():
                        tensor_data = self.data[start_offset:end_offset].view(tensor_shape)
                        result.append(tensor_data)
                    elif self.has_columnwise_data():
                        tensor_data = self.columnwise_data[start_offset:end_offset].view(
                            tensor_shape
                        )
                        result.append(tensor_data)
                    else:
                        raise RuntimeError("GroupedTensor has no data to split")
                else:
                    # All same shape case: offsets are implied by the index.
                    numel = tensor_shape[0] * tensor_shape[1]
                    start_offset = i * numel
                    end_offset = start_offset + numel
                    if self.has_data():
                        tensor_data = self.data[start_offset:end_offset].view(tensor_shape)
                        result.append(tensor_data)
                    elif self.has_columnwise_data():
                        tensor_data = self.columnwise_data[start_offset:end_offset].view(
                            tensor_shape
                        )
                        result.append(tensor_data)
                    else:
                        raise RuntimeError("GroupedTensor has no data to split")
            return result
        # Case 2: Quantized tensors
        recipe = self.quantizer._get_compatible_recipe()
        for i in range(self.num_tensors):
            # Get tensor shape
            tensor_shape = self.shape[i]
            numel = tensor_shape[0] * tensor_shape[1]
            # Get data offsets
            if self.offsets is not None:
                data_start = self.offsets[i]
                data_end = data_start + numel
            else:
                # All same shape
                data_start = i * numel
                data_end = data_start + numel
            # Special shape handling for NVFP4: element offsets are halved,
            # presumably because two 4-bit values share one byte of the uint8
            # buffer — confirm against the quantizer's packing convention.
            nvfp4 = self.quantizer._get_compatible_recipe().nvfp4()
            if nvfp4:
                data_start = data_start // 2
                data_end = data_end // 2
            # Extract rowwise and columnwise data
            rowwise_data = None
            columnwise_data = None
            if self.has_data():
                if nvfp4:
                    rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape)
                else:
                    rowwise_tensor_shape = tensor_shape
                rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape)
            if self.has_columnwise_data():
                columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape)
                if nvfp4:
                    columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4(
                        columnwise_tensor_shape
                    )
                columnwise_data = self.columnwise_data[data_start:data_end].view(
                    columnwise_tensor_shape
                )
            # MXFP8 format
            if recipe.mxfp8():
                # Extract scale_inv data
                rowwise_scale_inv = None
                columnwise_scale_inv = None
                if self.scale_inv is not None and self.scale_inv_offsets is not None:
                    scale_start = self.scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        scale_end = self.scale_inv_offsets[i + 1]
                    else:
                        # Last tensor: its slice runs to the end of the buffer.
                        scale_end = self.scale_inv.numel()
                    # Calculate expected scale shape for MXFP8
                    scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
                    rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
                if (
                    self.columnwise_scale_inv is not None
                    and self.columnwise_scale_inv_offsets is not None
                ):
                    cscale_start = self.columnwise_scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        cscale_end = self.columnwise_scale_inv_offsets[i + 1]
                    else:
                        cscale_end = self.columnwise_scale_inv.numel()
                    cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
                    columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
                        cscale_shape
                    )
                # Internal quantizers get the lightweight storage class; the
                # public path gets the full tensor subclass.
                if self.quantizer.internal:
                    mxfp8_tensor_class = MXFP8TensorStorage
                else:
                    mxfp8_tensor_class = MXFP8Tensor
                tensor = mxfp8_tensor_class(
                    shape=tensor_shape,
                    dtype=self.dtype,
                    rowwise_data=rowwise_data,
                    rowwise_scale_inv=rowwise_scale_inv,
                    columnwise_data=columnwise_data,
                    columnwise_scale_inv=columnwise_scale_inv,
                    fp8_dtype=self.quantizer.dtype,
                    quantizer=self.quantizer,
                    with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
                )
                result.append(tensor)
            # Delayed scaling or current scaling (both use Float8TensorStorage)
            elif recipe.delayed() or recipe.float8_current_scaling():
                # Scale inverse - one per tensor; slice keeps a view, not a copy
                scale_inv = None
                if self.scale_inv is not None:
                    scale_inv = self.scale_inv[i : i + 1]
                if self.quantizer.internal:
                    float8_tensor_class = Float8TensorStorage
                else:
                    float8_tensor_class = Float8Tensor
                tensor = float8_tensor_class(
                    shape=tensor_shape,
                    dtype=self.dtype,
                    data=rowwise_data,
                    fp8_scale_inv=scale_inv,
                    fp8_dtype=self.quantizer.dtype,
                    quantizer=self.quantizer,
                    data_transpose=columnwise_data,
                )
                result.append(tensor)
            # Float8 block scaling
            elif recipe.float8_block_scaling():
                # Extract scale_inv data
                rowwise_scale_inv = None
                columnwise_scale_inv = None
                if self.scale_inv is not None and self.scale_inv_offsets is not None:
                    scale_start = self.scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        scale_end = self.scale_inv_offsets[i + 1]
                    else:
                        scale_end = self.scale_inv.numel()
                    # Get scale shape from quantizer
                    scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
                    rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
                if (
                    self.columnwise_scale_inv is not None
                    and self.columnwise_scale_inv_offsets is not None
                ):
                    cscale_start = self.columnwise_scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        cscale_end = self.columnwise_scale_inv_offsets[i + 1]
                    else:
                        cscale_end = self.columnwise_scale_inv.numel()
                    # Get columnwise scale shape from quantizer
                    cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
                    columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
                        cscale_shape
                    )
                # Compute is_2D_scaled and data_format from quantizer attributes
                is_2D_scaled = self.quantizer.block_scaling_dim == 2
                if self.quantizer.internal:
                    float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage
                else:
                    float8_blockwise_q_tensor_class = Float8BlockwiseQTensor
                tensor = float8_blockwise_q_tensor_class(
                    shape=tensor_shape,
                    dtype=self.dtype,
                    rowwise_data=rowwise_data,
                    rowwise_scale_inv=rowwise_scale_inv,
                    columnwise_data=columnwise_data,
                    columnwise_scale_inv=columnwise_scale_inv,
                    fp8_dtype=self.quantizer.dtype,
                    quantizer=self.quantizer,
                    is_2D_scaled=is_2D_scaled,
                )
                result.append(tensor)
            # NVFP4 format
            elif recipe.nvfp4():
                # Extract scale_inv data
                rowwise_scale_inv = None
                columnwise_scale_inv = None
                amax_rowwise = None
                amax_columnwise = None
                if self.scale_inv is not None and self.scale_inv_offsets is not None:
                    scale_start = self.scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        scale_end = self.scale_inv_offsets[i + 1]
                    else:
                        scale_end = self.scale_inv.numel()
                    # Get scale shape from quantizer
                    scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
                    rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
                if (
                    self.columnwise_scale_inv is not None
                    and self.columnwise_scale_inv_offsets is not None
                ):
                    cscale_start = self.columnwise_scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        cscale_end = self.columnwise_scale_inv_offsets[i + 1]
                    else:
                        cscale_end = self.columnwise_scale_inv.numel()
                    # Get columnwise scale shape from quantizer
                    cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
                    columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
                        cscale_shape
                    )
                # Extract amax - one per tensor
                if self.amax is not None:
                    amax_rowwise = self.amax[i : i + 1]
                if self.columnwise_amax is not None:
                    amax_columnwise = self.columnwise_amax[i : i + 1]
                if self.quantizer.internal:
                    nvfp4_tensor_class = NVFP4TensorStorage
                else:
                    nvfp4_tensor_class = NVFP4Tensor
                tensor = nvfp4_tensor_class(
                    shape=tensor_shape,
                    dtype=self.dtype,
                    rowwise_data=rowwise_data,
                    rowwise_scale_inv=rowwise_scale_inv,
                    columnwise_data=columnwise_data,
                    columnwise_scale_inv=columnwise_scale_inv,
                    amax_rowwise=amax_rowwise,
                    amax_columnwise=amax_columnwise,
                    fp4_dtype=self.quantizer.dtype,
                    quantizer=self.quantizer,
                    with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
                )
                result.append(tensor)
            else:
                raise ValueError(f"Unsupported quantization recipe: {recipe}")
        return result
    @staticmethod
    def create_and_quantize(
        tensors: List[torch.Tensor],  # was annotated `int`; body does len(tensors) / t.shape
        quantizer: None | Quantizer,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> "GroupedTensor":  # was annotated Tuple[...]; body returns the GroupedTensor itself
        """
        Quantize given tensors into quantized tensors with underlying
        storage allocated in a GroupedTensor.

        Allocates a GroupedTensor sized for the shapes of ``tensors``, then
        quantizes each input into its slot. Returns the GroupedTensor (its
        per-tensor views are available via ``quantized_tensors`` /
        ``split_into_quantized_tensors()``), not the tuple of tensors.
        """
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=len(tensors),
            shape=[t.shape for t in tensors],
            quantizer=quantizer,
            device=device,
            dtype=dtype,
        )
        grouped_tensor.quantize(tensors, noop_flag=noop_flag)
        return grouped_tensor
def quantize(
self,
tensors: List[torch.Tensor],
noop_flag: Optional[torch.Tensor] = None,
) -> Tuple[QuantizedTensorStorage, ...]:
"""
Quantize the GroupedTensor inplace.
"""
quantized_tensors = self.split_into_quantized_tensors()
for i in range(self.num_tensors):
self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag)
return quantized_tensors
......@@ -111,6 +111,24 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
if t is not None:
t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another MXFP8TensorStorage."""
if not isinstance(src, MXFP8TensorStorage):
raise TypeError("copy_from_storage expects MXFP8TensorStorage")
if self._fp8_dtype != src._fp8_dtype:
raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......
......@@ -136,6 +136,26 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
if t is not None:
t.data = _empty_tensor()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
"""Copy data buffers from another NVFP4TensorStorage."""
if not isinstance(src, NVFP4TensorStorage):
raise TypeError("copy_from_storage expects NVFP4TensorStorage")
if self._fp4_dtype != src._fp4_dtype:
raise RuntimeError("FP4 dtype mismatch in copy_from_storage")
if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales:
raise RuntimeError("Scale layout mismatch in copy_from_storage")
def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
if dst is not None and src_tensor is not None:
dst.copy_(src_tensor)
_copy_optional(self._rowwise_data, src._rowwise_data)
_copy_optional(self._columnwise_data, src._columnwise_data)
_copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv)
_copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv)
_copy_optional(self._amax_rowwise, src._amax_rowwise)
_copy_optional(self._amax_columnwise, src._amax_columnwise)
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment