Commit c1a1c04e authored by wenjh

Merge nv_main(2.10) to main


Signed-off-by: wenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Functionality for CPU offloading of tensors saved for backward pass."""
from __future__ import annotations
from contextlib import nullcontext
from typing import Any, Dict, Optional
import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False
CPUOffloadedLayer = False
def get_cpu_offloading():
global CPUOffloadEnabled
return CPUOffloadEnabled
def set_cpu_offloading(cpu_offloading):
global CPUOffloadEnabled
CPUOffloadEnabled = cpu_offloading
def mark_activation_offload(*tensors):
"""Set the type of the offloading needed for a tensor."""
if TEDebugState.debug_enabled:
raise RuntimeError("CPU offload is not supported in debug mode.")
for tensor in tensors:
if tensor is None:
continue
if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
tensor.activation_offloading = True
else:
data_tensors = tensor.get_data_tensors()
for tensor in data_tensors:
if tensor is not None:
tensor.activation_offloading = True
# This is a hack to force-clear the tensor after it is offloaded.
# It is needed because the .*TensorStorage classes are saved in the ctx
# and keep references to their data tensors.
tensor.needs_force_clear = True
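# Illustrative usage sketch (hypothetical tensors, not part of this module's API):
#     act = torch.randn(1024, 1024, device="cuda")
#     mark_activation_offload(act)        # plain torch.Tensor path
#     mark_activation_offload(fp8_act)    # QuantizedTensorStorage path, assuming `fp8_act`
#                                         # exposes get_data_tensors()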
def is_cpu_offload_enabled() -> bool:
"""Check if CPU offloading is currently enabled."""
return CPUOffloadEnabled
def is_current_layer_offloaded() -> bool:
"""Check if current layers is being offloaded."""
return CPUOffloadedLayer
class CpuOffloadSavedTensorHook:
"""Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
In this context, the ``on_save_for_backward`` method will be called every time
a tensor is saved for backward (this includes intermediary results saved using
:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
also those recorded by a PyTorch-defined operation).
The ``on_get_saved_tensor`` method will be called when the backward function
of this op attempts to retrieve the saved tensor from context (this includes
:func:`torch.Tensor.backward()` or :func:`torch.autograd.grad()`). It takes
as input the return value of ``on_save_for_backward``, and is meant to return
an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of
size, device and element values.
Example:
>>> import torch
>>> from typing import Any
>>>
>>> class DummyHook(CpuOffloadSavedTensorHook):
...
... def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
... logging.info("On save", tensor)
... return (tensor,)
...
... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
... logging.info("On get", saved_state)
... tensor, = saved_state
... return tensor
...
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with DummyHook():
... y = a * b
...
On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
>>> y.sum().backward()
On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)
"""
def __init__(self) -> None:
self.inside_context = False
def __enter__(self):
global CPUOffloadEnabled
CPUOffloadEnabled = True
self.inside_context = True
torch._C._autograd._push_saved_tensors_default_hooks(
self.on_save_for_backward, self.on_get_saved_tensor
)
def __exit__(self, *args: Any):
global CPUOffloadEnabled
CPUOffloadEnabled = False
self.inside_context = False
torch._C._autograd._pop_saved_tensors_default_hooks()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
"""On save for backward."""
raise NotImplementedError(
"`on_save_for_backward: Callable[[torch.Tensor], Any]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks"
)
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
"""On get saved tensor."""
raise NotImplementedError(
"`on_get_saved_tensors: Callable[[Any], torch.Tensor]`"
"is not implemented in CpuOffloadHook class. Inherit "
"this class and implement your custom hooks"
)
class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
"""Context-manager that offloads/recovers tensors through an offload hander.
The hook just offloads/recovers the tensor object to the handler through `tensor_push`
and `tensor_pop` interface. How the offload-handler manages the offloading, recovering
or prefetching timing is transparent to this hook.
"""
def __init__(
self,
offload_handler: OffloadHandler,
handler_extra_kwargs: Optional[Dict[str, Any]] = None,
debug: bool = False,
) -> None:
if handler_extra_kwargs is None:
handler_extra_kwargs = {}
self.debug: bool = debug
self.offload_handler: OffloadHandler = offload_handler
self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs
super().__init__()
def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
return retrieve_identifier
def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)
return tensor
class OffloadHandler:
"""A base class for CPU offload-handler."""
def __init__(self) -> None:
pass
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
"""Tensor push."""
raise NotImplementedError(
"`tensor_push is not implented in OffloadHandler class. "
"Inherit this class and implement your custom tensor_push."
)
def tensor_pop(self, tensor_tag: Any, **kwargs):
"""Tensor pop."""
raise NotImplementedError(
"`tensor_pop is not implented in OffloadHandler class. "
"Inherit this class and implement your custom tensor_pop."
)
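# A minimal custom handler sketch (illustrative only; the real handlers below add grouping,
# CUDA streams and prefetching). It moves every pushed tensor to CPU and back on demand:
#
#     class NaiveOffloadHandler(OffloadHandler):
#         def __init__(self):
#             super().__init__()
#             self._store = {}
#
#         def tensor_push(self, tensor, **kwargs):
#             tag = len(self._store)
#             self._store[tag] = (tensor.device, tensor.to("cpu"))
#             return tag
#
#         def tensor_pop(self, tensor_tag, **kwargs):
#             dev, cpu_copy = self._store.pop(tensor_tag)
#             return cpu_copy.to(dev)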
class GroupCommitFunction(torch.autograd.Function):
"""this is a dummy op with output identical to input.
However, it is necessary for marking a timepoint for offload handler to
accomplish all synchronizations. Implementing it as a function is necessary
because we need to actions in both forward and backward.
"""
@staticmethod
def forward(ctx, tensor, cpu_offload_handler):
# pylint: disable=missing-function-docstring
cpu_offload_handler.on_group_commit_forward()
ctx.cpu_offload_handler = cpu_offload_handler
# return the identical tensor
return tensor
@staticmethod
def backward(ctx, grad_output):
# pylint: disable=missing-function-docstring
cpu_offload_handler = ctx.cpu_offload_handler
cpu_offload_handler.on_group_commit_backward()
return grad_output, None
group_prefetch_offload_commit = GroupCommitFunction.apply
class SynchronizedGroupOffloadHandler(OffloadHandler):
"""Offload Handler that offloads/reloads in a synchronized way.
The device-to-host and host-to-device copies happen in the same stream
as the computation kernels, so the copies block computation.
"""
def __init__(
self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False
) -> None:
super().__init__()
self.num_offload_group = num_offload_group
self.tensor_need_offloading_checker = tensor_need_offloading_checker
self.debug = debug
self.groupid_reset()
def groupid_reset(self):
"""Groupid reset."""
# Data structures to label saved tensors and book-keep their cpu copies.
# Currently, on push, we create a new cpu tensor and copy into it; on pop, we copy
# the tensor back to gpu and delete the cpu copy.
# These will increment whenever `group_commit()` is invoked
self.current_group, self.tensor_count_current_group = (0, 0)
self.torch_tensor_count = 0
self.tensor_tag_to_state = {}
def on_group_commit_forward(self):
"""On group commit forward."""
# finishing up with updating current group and tensor count
self.current_group += 1 # increment
self.tensor_count_current_group = 0 # reset
def on_group_commit_backward(self):
"""On group commit backward."""
self.current_group -= 1
assert self.current_group >= 0
@staticmethod
def offload(src_tensor, pin_memory=True):
"""Offload."""
cpu_backup = torch.empty(
src_tensor.size(),
dtype=src_tensor.dtype,
layout=src_tensor.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
state = (src_tensor.device, cpu_backup)
return state
@staticmethod
def reload(state, non_blocking=None, copy_buffer=None):
"""Reload."""
dev, cpu_backup = state
if non_blocking is None:
non_blocking = cpu_backup.is_pinned()
if copy_buffer is None:
return cpu_backup.to(dev, non_blocking=non_blocking)
assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!"
copy_buffer.copy_(cpu_backup, non_blocking=non_blocking)
return copy_buffer
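# Illustrative round trip (sketch): `offload` makes a (pinned, by default) CPU copy and
# remembers the source device; `reload` copies it back, optionally into a caller-provided
# GPU buffer:
#     state = SynchronizedGroupOffloadHandler.offload(gpu_tensor)   # D2H copy
#     restored = SynchronizedGroupOffloadHandler.reload(state)      # H2D copy back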
def tensor_push(self, tensor: torch.Tensor, **kwargs):
"""Tensor push."""
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
tensor
):
state = SynchronizedGroupOffloadHandler.offload(tensor)
self.tensor_tag_to_state[tensor_tag] = state
else:
# will be offloaded together after group commit
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
assert tensor_tag in self.tensor_tag_to_state
state = self.tensor_tag_to_state.pop(tensor_tag)
if isinstance(state, tuple):
tensor = SynchronizedGroupOffloadHandler.reload(state)
else:
tensor = state
return tensor
class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
"""Compared to synchronize, this uses more memory because of the buffer but
achieves better performance due to the overlapping. D2h and h2d copying are
completely hidden behind computation if computation time of a layer is longer
than host-device communication time. Bulk offloading with delay and bulk reloading
with prefetch are implemented."""
def __init__(
self,
num_offload_group, # must be <= actual number of groups (number of commits)
num_model_group,
tensor_need_offloading_checker=(lambda t: True),
double_buffering=False,
debug=False,
) -> None:
super().__init__(
num_offload_group=num_offload_group,
tensor_need_offloading_checker=tensor_need_offloading_checker,
debug=debug,
)
# Number of layers in the model
self.num_layers = num_model_group
# Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {}
# Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {}
self.dereferencing_list = []
# Tracking the number of layers offloaded
self.offloaded_group_count = 0
# Core data structure that decides the window for offloading
self.layer_window_map = {}
# Data structures for double-buffered reloading
self.double_buffering = double_buffering
self.reload_double_buffer = [[], []]
self.double_buffer_created = False
# Logic to load-balance offloading across computation
# for optimal CPU/GPU interconnect usage
constant = 0
for i in range(self.num_offload_group):
self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1
if i < (self.num_layers % self.num_offload_group):
self.layer_window_map[i] += i + 1
constant = i + 1
else:
self.layer_window_map[i] += constant
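# Worked example (illustrative): with num_model_group=10 layers and num_offload_group=3,
# the loop above yields layer_window_map = {0: 3, 1: 6, 2: 9}; the d2h stream is
# synchronized (and the next group's offload started) when layers 3, 6 and 9 commit,
# spreading D2H traffic evenly across the forward pass.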
# allocate streams and events for synchronization
self.d2h_stream = torch.cuda.Stream()
self.h2d_stream = torch.cuda.Stream()
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
global CPUOffloadedLayer
torch_stray_tensor = isinstance(
tensor,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
)
is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage)
if not torch_stray_tensor:
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if is_quantized_tensor:
tensor_list, _ = tensor.prepare_for_saving()
self.tensor_tag_to_state[tensor_tag] = []
self.tensor_tag_to_buf[tensor_tag] = []
# Added support for de-duplicating FP8 param tensors
for _, value in self.fp8_tensor_object_map.items():
if tensor is value:
self.dereferencing_list.append(tensor_tag)
break
self.fp8_tensor_object_map[tensor_tag] = tensor
if isinstance(tensor, Float8Tensor):
self.float8_transpose_cache_valid[tensor_tag] = getattr(
tensor, "_transpose_invalid"
)
else:
tensor_list = [tensor]
for t in tensor_list:
if is_quantized_tensor:
self.tensor_tag_to_state[tensor_tag].append(t)
else:
self.tensor_tag_to_state[tensor_tag] = t
if (
self.current_group < self.num_offload_group
and self.tensor_need_offloading_checker(t)
):
if is_quantized_tensor:
self.tensor_tag_to_buf[tensor_tag].append(t)
# Need to clear the internal data reference for the quantized tensors
tensor.clear()
else:
self.tensor_tag_to_buf[tensor_tag] = t
# Needed to differentiate attention of non-offloaded layers:
# the QKV layout of a non-offloaded layer's attention needs
# to be modified while reloading
CPUOffloadedLayer = True
else:
tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1
self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag
def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop."""
global CPUOffloadedLayer
assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag)
# Handling the quantized tensor case specially here
if isinstance(tensor, list):
# If it's a duplicated tensor, we don't need to locally
# write back a tensor as it would already be written
if tensor_tag in self.dereferencing_list:
self.dereferencing_list.remove(tensor_tag)
else:
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
tensor = self.fp8_tensor_object_map.pop(tensor_tag)
if self.double_buffering:
tensor._do_not_clear = True
self.tensor_tag_to_buf.pop(tensor_tag, None)
# the tensor should have been copied back in on_group_commit_backward()
# which invokes bulk_reload_group.
assert not isinstance(tensor, tuple)
return tensor
def bulk_offload_group(self, group_to_offload):
"""Bulk offload group."""
with torch.cuda.stream(self.d2h_stream):
for tensor_tag, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_tag
if group_id == group_to_offload:
assert not isinstance(state, tuple)
is_quantized_tensor = isinstance(state, list)
if is_quantized_tensor:
tensor_list = state
self.tensor_tag_to_state[tensor_tag] = []
else:
tensor_list = [state]
for tensor_on_device in tensor_list:
# `tensor_offloaded` is a hacky way of dealing with columnwise-only
# quantized tensors for CPU offloading. The complication is due to
# the `rowwise_data` being `None`. The offloading checker incorrectly
# returns `False` and the entire `state` ([None, columnwise_tensor])
# is added to the tensor tag state dict. A better design would change
# how quantized tensors are kept track of in the offload handler.
# Currently at every stage it is ensured that a quantized tensor is a
# list whereas a non-quantized tensor is a standalone object, which is
# not good! TODO(@sanandaraj5597)
tensor_offloaded = False
# if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device):
tensor_offloaded = True
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
if is_quantized_tensor:
if tensor_offloaded:
self.tensor_tag_to_state[tensor_tag].append(state)
else:
self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
else:
self.tensor_tag_to_state[tensor_tag] = state
def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward."""
global CPUOffloadedLayer
# For the first group, kickstart the offload after we have
# the first compute completion
if current_group == 0:
self.d2h_stream.wait_stream(torch.cuda.current_stream())
if not self.double_buffer_created:
# Creating the first copy of double buffer for tensors that are offloaded
for tensor_tag, buf in self.tensor_tag_to_buf.items():
if isinstance(buf, list):
for b in buf:
self.reload_double_buffer[0].append(
torch.empty_like(b) if self.double_buffering else None
)
else:
self.reload_double_buffer[0].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.bulk_offload_group(current_group)
# Window map data structure helps us synchronize based on number
# of layers offloaded
if self.layer_window_map[self.offloaded_group_count] == current_group:
# Stream synchronization both ways
self.d2h_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.d2h_stream)
# Time to free the activation memory after usage
for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorStorage class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
# Release the pointer to the tensor
self.tensor_tag_to_buf[tensor_tag] = None
# Time to offload the next group
if self.offloaded_group_count < (self.num_offload_group - 1):
self.bulk_offload_group(self.offloaded_group_count + 1)
# Increment the offload group count to keep track
self.offloaded_group_count += 1
if current_group == (self.num_offload_group - 1):
CPUOffloadedLayer = False
if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1):
for buf in self.reload_double_buffer[0]:
self.reload_double_buffer[1].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.double_buffer_created = True
def on_group_commit_forward(self):
"""This function will cause host device synchronization"""
# handle synchronization events
self.synchronize_on_group_commit_forward(self.current_group)
super().on_group_commit_forward()
def bulk_reload_group(self, group_to_reload):
"""Bulk reload group."""
assert group_to_reload < self.num_offload_group
buffer_idx = 0
double_buffer_idx = group_to_reload % 2
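# Ping-pong between the two reload buffers: consecutive groups alternate buffers so that
# reloading one group does not overwrite data that the previously reloaded group may
# still be using.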
main_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.h2d_stream):
# move back tensors
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if isinstance(state, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state[1], device=torch.cuda.current_device()
)
recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, reload_buffer
)
buffer_idx = buffer_idx + 1
self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if isinstance(state_tuple, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state_tuple[1], device=torch.cuda.current_device()
)
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(
state_tuple,
True,
reload_buffer,
)
)
buffer_idx = buffer_idx + 1
else:
tensor_list.append(state_tuple)
# No need to write back the duplicated tensor again
# to the same location; this check ensures that
if tensor_label in self.dereferencing_list:
self.dereferencing_list.remove(tensor_label)
else:
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(
tensor_list
)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
self.float8_transpose_cache_valid.pop(tensor_label)
)
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
tensor_label
)
def on_group_commit_backward(self):
# First decrement the current group.
# After the last commit in forward, the counter is one past the last group;
# each backward commit decrements it, so it should finally reach 0.
self.current_group -= 1
assert self.current_group >= 0
# Layer window data structure helps us to reload at right times
if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:
# Stream synchronization both ways
self.h2d_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(self.h2d_stream)
# Time to reload the next group
self.bulk_reload_group(self.offloaded_group_count - 1)
# Decrease the offloading group counter
self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0
# Last group computation needs to wait till all the reloads complete
if self.current_group == 0:
torch.cuda.current_stream().wait_stream(self.h2d_stream)
self.offloaded_group_count = 0
def get_cpu_offload_context(
enabled: bool = False,
num_layers: int = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = False,
double_buffering: bool = False,
):
"""
This function returns the CPU Offload context and the synchronizer function that needs to be
used after every transformer layer. Returns `nullcontext()` if offloading is not enabled.
Usage:
.. code-block:: python
cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)
with cpu_offload_context:
out = te_layer.forward(inp_tensor)
out = cpu_offload_synchronizer(out)
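A fuller sketch for a multi-layer model (illustrative; ``te_layers`` and ``hidden`` are
hypothetical names, not part of this API):
.. code-block:: python
    ctx, sync = get_cpu_offload_context(
        enabled=True, num_layers=2, model_layers=4, double_buffering=True,
    )
    for layer in te_layers:
        with ctx:
            hidden = layer(hidden)
        hidden = sync(hidden)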
Parameters
----------
enabled: bool, default = `False`
When set to True, CPU Offloading functionality is enabled.
num_layers: int, default = 1
Determines the number of transformer layers
you want to offload activations/weights for.
model_layers: int, default = 1
Number of layers in the model that will be used under this context.
offload_activations: bool, default = `True`
When set to `True`, offloads the activations for the TE layer.
offload_weights: bool, default = `False`
When set to `True`, offloads the weights for the TE layer. Deprecated; currently has no effect.
double_buffering: bool, default = `False`
When set to `True`, uses double buffering for offloading.
"""
if not offload_weights and not offload_activations:
raise ValueError(
"CPU Offloading is enabled while it is not "
"mentioned what to offload (weights/activations)"
)
if offload_weights:
import warnings
warnings.warn(
"Offloading weights is deprecated. Using offload_weights=True does not have any"
" effect.",
DeprecationWarning,
)
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations:
return nullcontext(), lambda x: x
def tensor_need_offloading_checker_activations(tensor):
return hasattr(tensor, "activation_offloading")
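# Only tensors previously tagged by mark_activation_offload() carry the
# `activation_offloading` attribute, so everything else stays on device.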
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers,
num_model_group=model_layers,
tensor_need_offloading_checker=tensor_need_offloading_checker,
double_buffering=double_buffering,
)
def group_prefetch_offload_commit_async(tensor):
return group_prefetch_offload_commit(tensor, cpu_offload_handler)
if enabled:
return (
CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
group_prefetch_offload_commit_async,
)
return nullcontext(), group_prefetch_offload_commit_async
...
@@ -190,8 +190,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
  const std::vector<size_t> meta_shape{1};
  ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
  ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
-  auto scale_inv_dtype =
-      (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
+  auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING)   ? DType::kFloat8E8M0
+                         : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
+                                                                   : DType::kFloat32;
  ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
  ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype,
                               columnwise_scale_inv_shape);
...
...
@@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-   int64_t window_size_right);
+   int64_t window_size_right, bool return_max_logit, bool cuda_graph);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
                                                       const std::vector<size_t> &shape, DType dtype,
...
@@ -94,7 +94,7 @@ std::vector<py::object> fused_attn_fwd(
    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
    const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
-   size_t rng_elts_per_thread);
+   size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph);
std::vector<py::object> fused_attn_bwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
...
@@ -106,7 +106,7 @@ std::vector<py::object> fused_attn_bwd(
    const std::vector<at::Tensor> Aux_CTX_Tensors,
    const std::optional<at::Tensor> cu_seqlens_q_padded,
    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
-   py::handle dp_quantizer, py::handle dqkv_quantizer);
+   py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
...
@@ -384,6 +384,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
                              const int cp_rank);
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+                              const std::optional<at::Tensor> start_positions,
                               const NVTE_QKV_Format qkv_format, const bool interleaved,
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                               const int cp_rank);
...
...
@@ -163,6 +163,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+                              const std::optional<at::Tensor> start_positions,
                               const NVTE_QKV_Format qkv_format, const bool interleaved,
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                               const int cp_rank) {
...
@@ -180,6 +181,12 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
  auto freqs_cu = makeTransformerEngineTensor(freqs);
  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
+ auto start_positions_cu = TensorWrapper();  // empty start_positions tensor
+ if (start_positions) {
+   start_positions_cu = makeTransformerEngineTensor(start_positions.value());
+   TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor");
+ }
  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
    TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
    TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
...
@@ -208,8 +215,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
    nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                            input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
-                            max_s, b, h, d, d2, stride_t,
+                            start_positions_cu.data(), input_grads_cu.data(), qkv_format,
+                            interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
                             /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
    return input_grads;
...
@@ -246,9 +253,9 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
    auto cu_seqlens_cu = TensorWrapper();  // empty cu_seqlens tensor
    nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                            input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
-                            h, d, d2, stride_s, stride_b, stride_h, stride_d,
-                            at::cuda::getCurrentCUDAStream());
+                            start_positions_cu.data(), input_grads_cu.data(), qkv_format,
+                            interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b,
+                            stride_h, stride_d, at::cuda::getCurrentCUDAStream());
    return input_grads;
}
...
...
@@ -45,14 +45,15 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-   int64_t window_size_right) {
+   int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
#ifdef __HIP_PLATFORM_AMD__
  return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
#else
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
      bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
-     max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
+     max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
+     return_max_logit, cuda_graph);
  return fused_attention_backend;
#endif
}
...
@@ -110,7 +111,7 @@ std::vector<py::object> fused_attn_fwd(
    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
    const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
-   size_t rng_elts_per_thread) {
+   size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
#else
...
@@ -235,8 +236,9 @@ std::vector<py::object> fused_attn_fwd(
        te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
        te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
-       attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
-       window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
+       return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
+       softmax_type, window_size[0], window_size[1], workspace.data(),
+       at::cuda::getCurrentCUDAStream());
  });
  // allocate memory for workspace and auxiliary output tensors
...
@@ -256,7 +258,9 @@ std::vector<py::object> fused_attn_fwd(
  };
  // allocate memory for nvte_aux_tensor_pack.tensors
  // f16_max512 : S [b, h, sq, skv]
- // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
+ // f16_arbitrary:
+ //   return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
+ //   return_max_logit=true:  Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
  // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
  size_t i = 0;
  at::Tensor output_tensor;
...
@@ -265,8 +269,8 @@ std::vector<py::object> fused_attn_fwd(
      allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
                    static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
  set_tensor_param(i++, output_tensor);
- // fp8 has an additional softmax stats tensor, ZInv
- if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
+ // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor
+ if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    output_tensor =
        allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
                      static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
...
@@ -292,8 +296,9 @@ std::vector<py::object> fused_attn_fwd(
        te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
        te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
-       attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
-       window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
+       return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
+       softmax_type, window_size[0], window_size[1], workspace.data(),
+       at::cuda::getCurrentCUDAStream());
  });
  // destroy tensor wrappers, but not allocated memory
...
@@ -315,7 +320,7 @@ std::vector<py::object> fused_attn_bwd(
    const std::vector<at::Tensor> Aux_CTX_Tensors,
    const std::optional<at::Tensor> cu_seqlens_q_padded,
    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
-   py::handle dp_quantizer, py::handle dqkv_quantizer) {
+   py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
#else
...
@@ -533,13 +538,14 @@ std::vector<py::object> fused_attn_bwd(
  // populate tensors with appropriate shapes and dtypes
  NVTE_SCOPED_GIL_RELEASE({
-   nvte_fused_attn_bwd(
-       te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
-       &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
-       te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-       te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
-       attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
-       window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
+   nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
+                       te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
+                       te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
+                       te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+                       te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
+                       max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
+                       softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
+                       workspace.data(), at::cuda::getCurrentCUDAStream());
  });
  // allocate memory for workspace
...
@@ -549,13 +555,14 @@ std::vector<py::object> fused_attn_bwd(
  // execute kernel
  NVTE_SCOPED_GIL_RELEASE({
-   nvte_fused_attn_bwd(
-       te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
-       &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
-       te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
-       te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
-       attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
-       window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
+   nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
+                       te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
+                       te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
+                       te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+                       te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
+                       max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
+                       softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
+                       workspace.data(), at::cuda::getCurrentCUDAStream());
  });
  // destroy tensor wrappers
...
...
@@ -491,6 +491,207 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
  return retval;
}
// allocate fp4 data, fp8 scalings, and amax values
// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN]
// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nvfp4_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<NVFP4Quantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
// Number of tensors
const size_t num_tensors = shape_list.size();
if (num_tensors == 0) {
return retval;
}
// Quantization parameters
const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
const auto fp4_dtype = quantizer_cpp_list[0]->dtype;
constexpr size_t scale_elem_size = 1;
// Helper function to construct tensor view
// Note: Deleter holds a shared_ptr for the buffer, so the buffer
// will survive until all views are deleted.
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
size_t offset, at::ScalarType dtype) -> at::Tensor {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
bool is_empty_shape = product(shape) == 0;
if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
}
return at::from_blob(
buffer->data_ptr<uint8_t>() + offset, shape_int64,
[buffer](void *) {}, // deleter holds shared_ptr
at::device(at::kCUDA).dtype(dtype));
};
// Lambda function for converting std::vector<size_t> shape to NVFP4 shape (last dim divided by 2)
auto to_fp4_shape = [](const std::vector<size_t> &shape) {
std::vector<size_t> fp4_shape(shape.begin(), shape.end());
if (!fp4_shape.empty()) {
fp4_shape.back() /= 2;
}
return fp4_shape;
};
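// For example (illustrative): a logical shape {128, 256} of FP4 values is stored as a
// uint8 tensor of shape {128, 128}, i.e. two 4-bit elements packed per byte.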
// Allocate row-wise data
std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list, amax_rowwise_list;
std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
if (rowwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_shapes.emplace_back(shape_list[i]);
rowwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
// Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes).
// Integer arithmetic: ceil(product / 2) == (product + 1) / 2.
buffer_size += (product(rowwise_data_shapes[i]) + 1) / 2;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
amax_offsets.push_back(buffer_size);
// amax is scalar in fp32, 4 bytes each
buffer_size += 4;
}
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]),
data_offsets[i], torch::kUInt8));
rowwise_scale_list.emplace_back(
make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
amax_rowwise_list.emplace_back(
make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kUInt8));
}
}
// Allocate column-wise data
std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list, amax_columnwise_list;
std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
if (columnwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
// push the transposed shape into NVFP4 columnwise shape
// NVFP4 on SM100 is TN only
columnwise_data_shapes.emplace_back();
auto &shape = columnwise_data_shapes.back();
shape.push_back(shape_list[i].back());
for (size_t j = 0; j < shape_list[i].size() - 1; ++j) {
shape.push_back(shape_list[i][j]);
}
columnwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
// Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes).
// Integer arithmetic: ceil(product / 2) == (product + 1) / 2.
buffer_size += (product(columnwise_data_shapes[i]) + 1) / 2;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
amax_offsets.push_back(buffer_size);
// amax is scalar in fp32, 4 bytes each
buffer_size += 4;
}
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
columnwise_data_list.emplace_back(make_torch_view(
buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8));
columnwise_scale_list.emplace_back(
make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
amax_columnwise_list.emplace_back(
make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kUInt8));
}
}
// Construct nvfp4 tensors
py::handle NVFP4TensorClass(reinterpret_cast<PyObject *>(NVFP4TensorStoragePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
py::object columnwise_data =
(columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
py::object columnwise_scale =
(columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
py::object amax_rowwise = rowwise_usage ? py::cast(amax_rowwise_list[i]) : py::none();
py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none();
// Construct Python tensor
tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data,
columnwise_scale, amax_rowwise, amax_columnwise,
fp4_dtype, quantizer_py_list[i]));
// Construct C++ tensor
// Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor,
// then set the amax and amax_columnwise values.
{
auto tensor_wrapper = makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp4_dtype,
/*amax_ptr=*/nullptr,
/*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode);
// Set the amax rowwise and amax columnwise if available
if (rowwise_usage) {
tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
if (columnwise_usage) {
tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32,
std::vector<size_t>{1});
}
tensor_cpp_list.emplace_back(std::move(tensor_wrapper));
}
}
return retval;
}
}  // namespace
std::vector<py::object> split_quantize(const at::Tensor &tensor,
...
@@ -549,7 +750,8 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
  bool use_fused_bulk_alloc = true;
  for (size_t i = 0; i < quantizer_list.size(); i++) {
    if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) &&
-       !detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) {
+       !detail::IsMXFP8Quantizers(quantizer_list[i].ptr()) &&
+       !detail::IsNVFP4Quantizers(quantizer_list[i].ptr())) {
      use_fused_bulk_alloc = false;
      break;
    }
...
@@ -570,6 +772,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
  // TODO(zhongbo): make a better api to make this part less hacky
  bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr());
  bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr());
+ bool is_nvfp4 = detail::IsNVFP4Quantizers(quantizer_list[0].ptr());
  if (is_fp8_blockwise) {
    // FP8 block-scaling: construct output tensors with bulk allocations
    std::vector<Float8BlockQuantizer *> blockwise_quantizers;
...
@@ -586,6 +789,14 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
    }
    std::tie(output_py_list, output_cpp_list) =
        bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
+ } else if (is_nvfp4) {
+   // NVFP4: construct output tensors with bulk allocations
+   std::vector<NVFP4Quantizer *> nvfp4_quantizers;
+   for (auto &quantizer : quantizer_cpp_list) {
+     nvfp4_quantizers.push_back(static_cast<NVFP4Quantizer *>(quantizer.get()));
+   }
+   std::tie(output_py_list, output_cpp_list) =
+       bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers);
  } else {
    NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer");
  }
...
...
@@ -20,10 +20,11 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
  TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
  TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
+ auto* amax_ptr = amax.data_ptr<float>();
  TensorWrapper fake_te_output(
      nullptr, te_input.shape(),
      DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
-     amax.data_ptr<float>());
+     amax_ptr);
  nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
}
...
...@@ -1200,6 +1200,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1200,6 +1200,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
rowwise_scale_inv_shape.end()); rowwise_scale_inv_shape.end());
rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts);
rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
// hadamard amax kernel will zero out pointer with ZeroAmaxKernel
// nvte_compute_amax_with_config will zero out the pointer if needed
amax_rowwise = at::empty({1}, bit32_tensor_opts); amax_rowwise = at::empty({1}, bit32_tensor_opts);
} }
if (columnwise_usage) { if (columnwise_usage) {
...@@ -1213,6 +1215,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve ...@@ -1213,6 +1215,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
columnwise_data_tensor = columnwise_data_tensor =
at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts);
columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
// hadamard amax kernel will zero out pointer with ZeroAmaxKernel
// nvte_compute_amax_with_config will zero out the pointer if needed
amax_columnwise = at::empty({1}, bit32_tensor_opts); amax_columnwise = at::empty({1}, bit32_tensor_opts);
} }
...@@ -1352,6 +1356,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor( ...@@ -1352,6 +1356,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
} }
if (!amax_rowwise) { if (!amax_rowwise) {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
// hadamard amax kernel will zero out pointer with ZeroAmaxKernel
// nvte_compute_amax_with_config will zero out the pointer if needed
amax_rowwise = at::empty({1}, opts); amax_rowwise = at::empty({1}, opts);
tensor.attr("_amax_rowwise") = *amax_rowwise; tensor.attr("_amax_rowwise") = *amax_rowwise;
} }
...@@ -1392,7 +1398,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor( ...@@ -1392,7 +1398,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
} }
if (!amax_columnwise) { if (!amax_columnwise) {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
amax_columnwise = at::zeros({1}, opts); // The Hadamard amax kernel will zero out this buffer with ZeroAmaxKernel;
// nvte_compute_amax_with_config will zero out the buffer if needed.
amax_columnwise = at::empty({1}, opts);
tensor.attr("_amax_columnwise") = *amax_columnwise; tensor.attr("_amax_columnwise") = *amax_columnwise;
} }
} else { // columnwise_usage == false } else { // columnwise_usage == false
......
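The allocations above switch the NVFP4 amax buffers from at::zeros to at::empty, relying on the consuming kernels (ZeroAmaxKernel in the Hadamard amax path, or nvte_compute_amax_with_config) to clear them. A PyTorch analogue of the trade-off, purely for illustration:

import torch

dev = "cuda" if torch.cuda.is_available() else "cpu"

# Old behaviour: allocate and zero-fill (an extra memset kernel on the stream).
amax_zeroed = torch.zeros(1, dtype=torch.float32, device=dev)

# New behaviour: allocate only; safe as long as the consuming kernel is guaranteed
# to overwrite (or explicitly zero) the buffer before it is ever read.
amax_uninit = torch.empty(1, dtype=torch.float32, device=dev)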
...@@ -50,8 +50,6 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap ...@@ -50,8 +50,6 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
void* scale_inv_dptr = scale_inv.data_ptr; void* scale_inv_dptr = scale_inv.data_ptr;
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
// Reconstruct input only to avoid swizzling both directions if not needed.
// The specific dtype used is irrelevant, just needs to be correct bits.
transformer_engine::TensorWrapper input_cu(input.scaling_mode()); transformer_engine::TensorWrapper input_cu(input.scaling_mode());
transformer_engine::TensorWrapper output_cu(input.scaling_mode()); transformer_engine::TensorWrapper output_cu(input.scaling_mode());
...@@ -100,10 +98,14 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors( ...@@ -100,10 +98,14 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle."); NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING &&
tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) {
return std::nullopt; return std::nullopt;
} }
const auto scaling_mode = tensors.front().scaling_mode();
const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
std::vector<transformer_engine::TensorWrapper> wrappers; std::vector<transformer_engine::TensorWrapper> wrappers;
std::vector<NVTETensor> input_tensors, output_tensors; std::vector<NVTETensor> input_tensors, output_tensors;
...@@ -131,39 +133,44 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors( ...@@ -131,39 +133,44 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
// Allocate full buffer // Allocate full buffer
auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8));
const auto input_dtype =
(nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
const auto scale_inv_dtype =
(nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
auto& tensor = tensors[i]; auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i]; void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
auto input_shape = nvte_shape_to_vector(tensor.shape()); // auto input_shape = nvte_shape_to_vector(tensor.shape());
NVTEShape nvte_input_shape;
if (rowwise) {
nvte_input_shape = tensor.shape();
} else {
nvte_input_shape = tensor.get_columnwise_data().shape;
}
auto input_shape = nvte_shape_to_vector(nvte_input_shape);
// Reconstruct input only to avoid swizzling both directions if not needed. // Reconstruct input only to avoid swizzling both directions if not needed.
// Use any 8 bit type, it's irrelevant. // Use any 8 bit type, it's irrelevant.
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper input_cu(scaling_mode);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); transformer_engine::TensorWrapper output_cu(scaling_mode);
if (rowwise) { if (rowwise) {
input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
scale_inv_shapes[i]); output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
input_shape); scale_inv_shapes[i]);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor. // Set the swizzled scaling factor to the original tensor.
tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
scale_inv_shapes[i]);
} else { } else {
input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
input_shape); input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
scale_inv_shapes[i]); output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
output_cu.set_columnwise_data(tensor.columnwise_dptr(), scale_inv_shapes[i]);
transformer_engine::DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_scale_inv(
swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]);
// Set the swizzled scaling factor to the original tensor. // Set the swizzled scaling factor to the original tensor.
tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); scale_inv_shapes[i]);
} }
input_tensors.emplace_back(input_cu.data()); input_tensors.emplace_back(input_cu.data());
......
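The hunk above extends multi-tensor swizzling from MXFP8-only to MXFP8 and NVFP4, choosing the data and scale-inverse dtypes per scaling mode (kFloat4E2M1/kFloat8E4M3 for NVFP4, kFloat8E4M3/kFloat8E8M0 for MXFP8). A small Python sketch of that selection; the enum and string names are illustrative stand-ins for the C++ DType values:

from enum import Enum, auto

class ScalingMode(Enum):
    MXFP8_1D_SCALING = auto()
    NVFP4_1D_SCALING = auto()

def select_swizzle_dtypes(mode: ScalingMode) -> tuple:
    """Return (data_dtype, scale_inv_dtype) names used when wrapping tensors for swizzle."""
    if mode is ScalingMode.NVFP4_1D_SCALING:
        # NVFP4: 4-bit E2M1 data with FP8 E4M3 scale inverses.
        return ("Float4E2M1", "Float8E4M3")
    # MXFP8: FP8 E4M3 data with E8M0 (power-of-two) scale inverses.
    return ("Float8E4M3", "Float8E8M0")

assert select_swizzle_dtypes(ScalingMode.NVFP4_1D_SCALING) == ("Float4E2M1", "Float8E4M3")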
...@@ -2,21 +2,21 @@ ...@@ -2,21 +2,21 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""GEMM API for experimental middleware between Transformer Engine and Kitchen.""" """GEMM API that enables custom GEMM logic for custom quantization recipes."""
from typing import Iterable, Optional from typing import Iterable, Optional
import torch import torch
from transformer_engine.pytorch.experimental.quantization import ( from transformer_engine.pytorch.custom_recipes.quantization import (
MMParams, MMParams,
GEMMType, GEMMType,
) )
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.tensor.utils import is_experimental from transformer_engine.pytorch.tensor.utils import is_custom
def experimental_gemm( def custom_gemm(
A: QuantizedTensorStorage, A: QuantizedTensorStorage,
B: QuantizedTensorStorage, B: QuantizedTensorStorage,
workspace: torch.Tensor, # pylint: disable=unused-argument workspace: torch.Tensor, # pylint: disable=unused-argument
...@@ -32,7 +32,7 @@ def experimental_gemm( ...@@ -32,7 +32,7 @@ def experimental_gemm(
grad: bool = False, grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]: ) -> Iterable[Optional[torch.Tensor]]:
"""Dispatch GEMM to quantizer's qgemm method.""" """Dispatch GEMM to quantizer's qgemm method."""
assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors" assert is_custom(A) and is_custom(B), "A and B must be custom tensors"
A, B = B, A A, B = B, A
......
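custom_gemm validates that both operands are custom quantized tensors and then defers to the quantizer's qgemm implementation (note the A/B swap before dispatch). A self-contained sketch of that dispatch pattern; the class and method names below are hypothetical stand-ins, not the Transformer Engine API:

import torch

class MyQuantizer:
    def qgemm(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # A real recipe would run a quantized GEMM; a plain matmul stands in here.
        return a @ b

class MyCustomTensor:
    custom = True  # mirrors the `custom` property checked by is_custom()

    def __init__(self, data: torch.Tensor, quantizer: MyQuantizer):
        self.data, self._quantizer = data, quantizer

def dispatch_gemm(A: MyCustomTensor, B: MyCustomTensor) -> torch.Tensor:
    assert A.custom and B.custom, "A and B must be custom tensors"
    return A._quantizer.qgemm(A.data, B.data)

q = MyQuantizer()
out = dispatch_gemm(MyCustomTensor(torch.randn(4, 8), q), MyCustomTensor(torch.randn(8, 2), q))
print(out.shape)  # torch.Size([4, 2])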
...@@ -9,9 +9,9 @@ from typing import Optional, Tuple, Union ...@@ -9,9 +9,9 @@ from typing import Optional, Tuple, Union
import torch import torch
from transformer_engine.pytorch.experimental import quantization from transformer_engine.pytorch.custom_recipes import quantization
from transformer_engine.pytorch.experimental import utils from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer
def nvfp4_ref_rht_2d_quantizer_factory(role): def nvfp4_ref_rht_2d_quantizer_factory(role):
...@@ -229,8 +229,8 @@ class NVFP4TensorRef(QuantizedTensorStorage): ...@@ -229,8 +229,8 @@ class NVFP4TensorRef(QuantizedTensorStorage):
_quantizer: Optional[Quantizer] = None _quantizer: Optional[Quantizer] = None
@property @property
def experimental(self) -> bool: def custom(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware.""" """Flag to indicate this quantized tensor is custom."""
return True return True
def prepare_for_saving( def prepare_for_saving(
...@@ -362,8 +362,8 @@ class NVFP4QuantizerRef(Quantizer): ...@@ -362,8 +362,8 @@ class NVFP4QuantizerRef(Quantizer):
self.with_random_sign_mask = with_random_sign_mask self.with_random_sign_mask = with_random_sign_mask
@property @property
def experimental(self) -> bool: def custom(self) -> bool:
"""Flag to indicate this quantizer is using experimental Kitchen middleware""" """Flag to indicate this quantizer is custom."""
return True return True
@staticmethod @staticmethod
......
...@@ -29,24 +29,25 @@ except ImportError: ...@@ -29,24 +29,25 @@ except ImportError:
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from . import torch_version from . import torch_version
from .utils import ( from .utils import (
is_non_tn_fp8_gemm_supported, is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data, safely_set_viewless_tensor_data,
needs_quantized_gemm, needs_quantized_gemm,
) )
from .constants import dist_group_type from .constants import dist_group_type
from .quantization import FP8GlobalStateManager, autocast from .quantization import FP8GlobalStateManager, autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.nvfp4_tensor import NVFP4Quantizer from .tensor.nvfp4_tensor import NVFP4Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer from .quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
from .tensor.storage.float8_tensor_storage import Float8TensorStorage from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .triton.pad import pad_columnwise_scale_inv
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
...@@ -1889,6 +1890,43 @@ def allreduce( ...@@ -1889,6 +1890,43 @@ def allreduce(
return inp, handle return inp, handle
def _get_module_fsdp_state(module):
"""
If module is an FSDP module, return its _FSDPState.
Otherwise, return the _FSDPState of the closest FSDP-wrapped parent module
in the module hierarchy that the module belongs to.
"""
if hasattr(module, "_get_fsdp_state"):
# This returns the correct FSDP state if the module itself is an FSDP module.
fsdp_state = module._get_fsdp_state()
elif getattr(module, "_te_cached_parent_fsdp_state", None) is not None:
# See if we have cached the parent fsdp state of the module
fsdp_state = module._te_cached_parent_fsdp_state
else:
from torch.distributed._composable_state import _module_state_mapping
# Otherwise, find the closest FSDP-wrapped ancestor of the module in the hierarchy and use its FSDP state.
min_nodes_in_parent = float("inf")
closest_parent_fsdp_mod = None
for fsdp_mod in _module_state_mapping.keys():
all_submodules = list(fsdp_mod.modules())
for submodule in all_submodules:
if submodule is module:
if min_nodes_in_parent > len(all_submodules):
closest_parent_fsdp_mod = fsdp_mod
min_nodes_in_parent = len(all_submodules)
if closest_parent_fsdp_mod is None:
raise RuntimeError(
"Module is not FSDP-wrapped and does not have any FSDP-wrapped parent modules."
)
fsdp_state = closest_parent_fsdp_mod._get_fsdp_state()
# Cache the parent fsdp state of the module to avoid recomputing
# the closest parent fsdp module.
module._te_cached_parent_fsdp_state = fsdp_state
return fsdp_state
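The search above effectively picks, among all FSDP-wrapped modules that contain the target, the one with the fewest submodules, since a smaller containing subtree means a closer ancestor. A toy illustration of that criterion without any distributed setup:

import torch

outer = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sequential(torch.nn.Linear(4, 4)))
inner = outer[1]
target = outer[1][0]

candidates = [outer, inner]  # stand-ins for FSDP-wrapped root modules
best = min(
    (m for m in candidates if any(sub is target for sub in m.modules())),
    key=lambda m: len(list(m.modules())),
)
assert best is inner  # the closer (smaller) wrapper wins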
def _fsdp_scatter_tensors( def _fsdp_scatter_tensors(
fsdp_group: dist_group_type, fsdp_group: dist_group_type,
*tensors: torch.Tensor, *tensors: torch.Tensor,
......
...@@ -322,14 +322,16 @@ def _make_graphed_callables( ...@@ -322,14 +322,16 @@ def _make_graphed_callables(
fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
graph_callables = [None for _ in range(len(flatten_sample_args))] graph_callables = [None for _ in range(len(flatten_sample_args))]
# For cases with multiple active RNG states, e.g. TP. # For cases with multiple active RNG states, e.g. TP.
if graph_safe_rng_available(): if graph_safe_rng_available():
for _, state in get_all_rng_states().items(): for _, state in get_all_rng_states().items():
for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs): for fwd_graph, bwd_graph, bwd_dw_graph in zip(fwd_graphs, bwd_graphs, bwd_dw_graphs):
fwd_graph.register_generator_state(state) fwd_graph.register_generator_state(state)
bwd_graph.register_generator_state(state) bwd_graph.register_generator_state(state)
bwd_dw_graph.register_generator_state(state)
mempool = graph_pool_handle() if pool is None else pool mempool = graph_pool_handle() if pool is None else pool
...@@ -366,21 +368,8 @@ def _make_graphed_callables( ...@@ -366,21 +368,8 @@ def _make_graphed_callables(
), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."
# Filter the TE modules that cudagraph can access. # Filter the TE modules that cudagraph can access.
visited_te_modules = set() visited_te_modules = {}
need_bwd_dw_graph = {}
def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
visited_te_modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
assert module._module_groups is not None, "Should have been initialized by warmup"
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
visited_te_modules.add(basic_op)
# Run warmup and do the above filtering. # Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()): with torch.cuda.stream(torch.cuda.Stream()):
...@@ -388,6 +377,31 @@ def _make_graphed_callables( ...@@ -388,6 +377,31 @@ def _make_graphed_callables(
args = sample_args[func_idx] args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx] kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx]
def hook_fn(
module, inputs, outputs, func_idx=func_idx
): # pylint: disable=unused-argument
modules = set()
if isinstance(module, TransformerEngineBaseModule):
modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
assert (
module._module_groups is not None
), "Should have been initialized by warmup"
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
modules.add(basic_op)
if modules:
if func_idx not in visited_te_modules:
visited_te_modules[func_idx] = modules
else:
visited_te_modules[func_idx].update(modules)
for warmup_iter in range(num_warmup_iters): for warmup_iter in range(num_warmup_iters):
hooks = [] hooks = []
for module in func.modules(): for module in func.modules():
...@@ -432,6 +446,15 @@ def _make_graphed_callables( ...@@ -432,6 +446,15 @@ def _make_graphed_callables(
module_params_with_grad module_params_with_grad
) )
per_callable_static_input_surfaces[func_idx] = static_input_surface per_callable_static_input_surfaces[func_idx] = static_input_surface
# Run wgrad. This is essential for some TE modules when they have
# delay_wgrad_compute enabled.
need_backward_dw = False
for module in visited_te_modules.get(func_idx, set()):
if hasattr(module, "need_backward_dw") and module.need_backward_dw():
need_backward_dw = True
module.backward_dw()
need_bwd_dw_graph[func_idx] = need_backward_dw
else: else:
grad_inputs = None grad_inputs = None
del outputs, grad_inputs del outputs, grad_inputs
...@@ -514,6 +537,17 @@ def _make_graphed_callables( ...@@ -514,6 +537,17 @@ def _make_graphed_callables(
allow_unused=allow_unused_input, allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward, retain_graph=retain_graph_in_backward,
) )
# If no module needs backward_dw, the bwd_dw_graph would be empty,
# so skip capturing it.
if need_bwd_dw_graph[per_callable_bwd_idx]:
bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[per_callable_bwd_idx]:
if (
hasattr(module, "need_backward_dw")
and module.need_backward_dw()
):
module.backward_dw()
# Constructs a tuple suitable for returning from Graphed.backward: # Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs # Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern. # that don't require grad. I couldn't think of a one-liner for this pattern.
...@@ -582,10 +616,12 @@ def _make_graphed_callables( ...@@ -582,10 +616,12 @@ def _make_graphed_callables(
# Capture backward graphs in reverse order # Capture backward graphs in reverse order
per_callable_static_grad_outputs = [] per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = [] per_callable_static_grad_inputs = []
for static_input_surface, static_outputs, bwd_graph in zip( for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip(
reversed(per_callable_static_input_surfaces), reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs), reversed(per_callable_static_outputs),
reversed(bwd_graphs), reversed(bwd_graphs),
reversed(bwd_dw_graphs),
reversed(range(len(per_callable_static_input_surfaces))),
): ):
# For now, assumes all static_outputs require grad # For now, assumes all static_outputs require grad
static_grad_outputs = tuple( static_grad_outputs = tuple(
...@@ -601,6 +637,11 @@ def _make_graphed_callables( ...@@ -601,6 +637,11 @@ def _make_graphed_callables(
allow_unused=allow_unused_input, allow_unused=allow_unused_input,
retain_graph=retain_graph_in_backward, retain_graph=retain_graph_in_backward,
) )
if need_bwd_dw_graph[bwd_idx]:
with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
for module in visited_te_modules[bwd_idx]:
if hasattr(module, "need_backward_dw") and module.need_backward_dw():
module.backward_dw()
# Constructs a tuple suitable for returning from Graphed.backward: # Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs that # Pads out the actually-needed grads with Nones in gradient slots for inputs that
# don't require grad. I couldn't think of a slick one-liner for this pattern. # don't require grad. I couldn't think of a slick one-liner for this pattern.
...@@ -715,6 +756,21 @@ def _make_graphed_callables( ...@@ -715,6 +756,21 @@ def _make_graphed_callables(
return functionalized return functionalized
def make_graphed_attribute_functions(graph_idx):
# Attach backward_dw as an attribute to the graphed callable.
def backward_dw():
if need_bwd_dw_graph.get(graph_idx, False):
bwd_dw_graphs[graph_idx].replay()
# Attach reset as an attribute to the graphed callable.
def reset():
fwd_graphs[graph_idx].reset()
bwd_graphs[graph_idx].reset()
bwd_dw_graphs[graph_idx].reset()
return backward_dw, reset
# Put together the final graphed callables # Put together the final graphed callables
ret = [] ret = []
for i in range(len(sample_args)): for i in range(len(sample_args)):
...@@ -732,9 +788,10 @@ def _make_graphed_callables( ...@@ -732,9 +788,10 @@ def _make_graphed_callables(
) )
func = graph_callables[i] func = graph_callables[i]
te_modules = visited_te_modules.get(i, set())
if isinstance(func, torch.nn.Module): if isinstance(func, torch.nn.Module):
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): def make_graphed_forward(func, graph_training_state, graphed, orig_fwd, te_modules):
def new_fwd(*user_args, **user_kwargs): def new_fwd(*user_args, **user_kwargs):
# If the module's training-or-eval state matches what we graphed, # If the module's training-or-eval state matches what we graphed,
# run the graph, otherwise run the original forward method # run the graph, otherwise run the original forward method
...@@ -743,7 +800,7 @@ def _make_graphed_callables( ...@@ -743,7 +800,7 @@ def _make_graphed_callables(
if FP8GlobalStateManager.is_fp8_enabled(): if FP8GlobalStateManager.is_fp8_enabled():
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
for m in func.modules(): for m in func.modules():
if m not in visited_te_modules: if m not in te_modules:
# Only Set the FP8 meta for the modules included by forward # Only Set the FP8 meta for the modules included by forward
continue continue
if isinstance(m, TransformerEngineBaseModule): if isinstance(m, TransformerEngineBaseModule):
...@@ -780,7 +837,7 @@ def _make_graphed_callables( ...@@ -780,7 +837,7 @@ def _make_graphed_callables(
return new_fwd return new_fwd
forward = make_graphed_forward(func, func.training, graphed, func.forward) forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules)
if _order is None: if _order is None:
func.forward = forward func.forward = forward
ret.append(func) ret.append(func)
...@@ -789,6 +846,10 @@ def _make_graphed_callables( ...@@ -789,6 +846,10 @@ def _make_graphed_callables(
else: else:
ret.append(graphed) ret.append(graphed)
backward_dw_func, reset_func = make_graphed_attribute_functions(i)
setattr(ret[-1], "backward_dw", backward_dw_func)
setattr(ret[-1], "reset", reset_func)
if just_one_callable: if just_one_callable:
return ret[0] return ret[0]
......
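With these changes each graphed callable gains backward_dw() and reset() attributes: backward_dw() replays the captured weight-gradient graph when any captured TE module had delayed wgrad pending, and reset() frees the forward, backward, and wgrad graphs. A hedged usage sketch, assuming a TE version where Linear exposes delay_wgrad_compute; module sizes and names are placeholders:

import torch
import transformer_engine.pytorch as te

# Assumes delay_wgrad_compute is available on te.Linear in the installed TE version,
# so that wgrad is excluded from the main backward graph and replayed separately.
model = te.Linear(1024, 1024, delay_wgrad_compute=True).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)
graphed_model = te.make_graphed_callables(model, (sample_input,))

out = graphed_model(sample_input)
out.sum().backward()         # replays the captured fwd/bwd graphs (dgrad only)
graphed_model.backward_dw()  # replays the captured wgrad graph, if one was needed
graphed_model.reset()        # releases the captured CUDA graphs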
...@@ -17,6 +17,7 @@ from types import MethodType ...@@ -17,6 +17,7 @@ from types import MethodType
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.distributed.tensor import DTensor
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
...@@ -38,7 +39,7 @@ from ..distributed import ( ...@@ -38,7 +39,7 @@ from ..distributed import (
_fsdp_gather_tensors, _fsdp_gather_tensors,
) )
from ..constants import dist_group_type from ..constants import dist_group_type
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
...@@ -707,6 +708,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -707,6 +708,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None self.activation_dtype: Optional[torch.dtype] = None
self.wgrad_accumulation_and_reduce_hooks = [] self.wgrad_accumulation_and_reduce_hooks = []
self.wgrad_store = None
if not TEDebugState.debug_enabled: if not TEDebugState.debug_enabled:
TEDebugState.initialize() TEDebugState.initialize()
...@@ -1288,7 +1290,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1288,7 +1290,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
metadata used in deferred initialization. metadata used in deferred initialization.
""" """
super().register_parameter(name, param) super().register_parameter(name, param)
self.param_init_meta[name] = _ParameterInitMeta(**kwargs) # Initialize param_init_meta exactly once during init. FSDP2 can call
# register_parameter again to convert parameters to DTensors, and it calls
# it without the custom FP8-specific kwargs that we need, so we don't want
# to reset/lose our FP8 init attributes.
if hasattr(self, "param_init_meta") and name not in self.param_init_meta:
self.param_init_meta[name] = _ParameterInitMeta(**kwargs)
def reset_parameters(self, defer_init: Optional[bool] = False) -> None: def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
""" """
...@@ -1300,10 +1307,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1300,10 +1307,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return return
for name, param in self.named_parameters(recurse=False): for name, param in self.named_parameters(recurse=False):
# Check if parameter is a DTensor (FSDP2) or regular tensor
is_dtensor = isinstance(param, DTensor)
dtensor_param = param if is_dtensor else None
# Need to update/quantize local tensor in case of DTensor
param = param._local_tensor if is_dtensor else param
# Ensure parameter is on a real device # Ensure parameter is on a real device
if param.device == torch.device("meta"): if param.device == torch.device("meta"):
param = torch.empty_like(param, device="cuda") param = torch.empty_like(param, device="cuda")
# Initialize the parameter values on device # Initialize the parameter values on device
init_fn = self.param_init_meta[name].init_fn init_fn = self.param_init_meta[name].init_fn
get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
...@@ -1332,7 +1343,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1332,7 +1343,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
raise RuntimeError("Weight quantizer has not been initialized") raise RuntimeError("Weight quantizer has not been initialized")
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.internal = False quantizer.internal = False
if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
device_mesh = dtensor_param.device_mesh
amax_reduction_group = (
device_mesh.get_group(mesh_dim="shard")
if device_mesh.ndim > 1
else device_mesh.get_group()
)
quantizer.amax_reduction_group = amax_reduction_group
quantizer.with_amax_reduction = True
# Quantize parameter # Quantize parameter
param = quantizer(param) param = quantizer(param)
...@@ -1340,7 +1359,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1340,7 +1359,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# NOTE: Currently this can only be broken when primary weights are in Fp8 but # NOTE: Currently this can only be broken when primary weights are in Fp8 but
# re-applying the nn.Parameter() wrap is a no-op when the input is already # re-applying the nn.Parameter() wrap is a no-op when the input is already
# a parameter so we always re-apply it just for extra safety. # a parameter so we always re-apply it just for extra safety.
param = torch.nn.Parameter(param) if is_dtensor:
# recreate the DTensor from the parameter.
dtensor_param = DTensor.from_local(
param,
device_mesh=dtensor_param.device_mesh,
placements=dtensor_param.placements,
shape=dtensor_param.size(),
stride=dtensor_param.stride(),
)
dtensor_param = torch.nn.Parameter(dtensor_param)
else:
param = torch.nn.Parameter(param)
# Keep high-precision values on CPU if needed # Keep high-precision values on CPU if needed
if high_precision_init_val is not None: if high_precision_init_val is not None:
...@@ -1368,8 +1398,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1368,8 +1398,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
param._high_precision_init_val = high_precision_init_val param._high_precision_init_val = high_precision_init_val
param.get_high_precision_init_val = MethodType(get, param) param.get_high_precision_init_val = MethodType(get, param)
param.clear_high_precision_init_val = MethodType(clear, param) param.clear_high_precision_init_val = MethodType(clear, param)
# Update the parameter based on its type
setattr(self, name, param) if not is_dtensor:
setattr(self, name, param)
else:
setattr(self, name, dtensor_param)
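For FSDP2, the code above initializes and quantizes only the DTensor's local shard and then rebuilds the DTensor with the original mesh, placements, shape, and stride. A hedged sketch of that round-trip; it assumes an initialized process group and device mesh, so it is indicative rather than directly runnable in isolation:

import torch
from torch.distributed.tensor import DTensor

def reinit_local_shard(param, init_fn):
    """Re-initialize a (possibly sharded) parameter in place and re-wrap it."""
    if isinstance(param, DTensor):
        local = param._local_tensor          # operate on this rank's shard only
        init_fn(local)                       # e.g. in-place init / quantization
        rebuilt = DTensor.from_local(
            local,
            device_mesh=param.device_mesh,
            placements=param.placements,
            shape=param.size(),
            stride=param.stride(),
        )
        return torch.nn.Parameter(rebuilt)
    init_fn(param)
    return torch.nn.Parameter(param)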
@abstractmethod @abstractmethod
def forward(self): def forward(self):
...@@ -1526,12 +1560,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1526,12 +1560,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
""" """
self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook) self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook)
def need_backward_dw(self):
"""
Check whether this module needs to execute the delayed weight gradient computation.
This method is called at the beginning of self.backward_dw() to determine whether it
should actually run or return immediately without doing anything.
Users can also call this method to check before invoking backward_dw().
"""
return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute()
def backward_dw(self): def backward_dw(self):
""" """
Execute the delayed weight gradient computation. Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients. This method is called after the main backward pass to compute weight gradients.
""" """
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop() (wgrad, bgrad), _ = self.wgrad_store.pop()
...@@ -1568,7 +1611,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1568,7 +1611,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
debug = False debug = False
else: else:
debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
self.debug_last_iteration = TEDebugState.get_iteration() self.debug_last_iteration = TEDebugState.get_iteration()
self.debug_enabled_in_this_iteration = debug
else:
# If this is the same iteration as the previous invocation of the module,
# reuse the debug value from the first invocation in that iteration.
debug = self.debug_enabled_in_this_iteration
return debug return debug
def no_debug_features_active(self, quantizers): def no_debug_features_active(self, quantizers):
......
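The need_backward_dw()/backward_dw() pair added above lets callers guard the delayed weight-gradient step. A toy stand-in for the pattern; the store's contents and pop() contract are simplified, not the TE wgrad_store API:

import torch

class ToyModule:
    def __init__(self, delayed_wgrads=None):
        # delayed_wgrads: list of (input, grad_output) pairs stashed during backward
        self.wgrad_store = delayed_wgrads

    def need_backward_dw(self) -> bool:
        return self.wgrad_store is not None and len(self.wgrad_store) > 0

    def backward_dw(self):
        if not self.need_backward_dw():
            return None
        inp, grad_out = self.wgrad_store.pop(0)
        return inp.t() @ grad_out  # weight gradient for a plain linear layer

m = ToyModule(delayed_wgrads=[(torch.randn(4, 3), torch.randn(4, 2))])
if m.need_backward_dw():
    print(m.backward_dw().shape)  # torch.Size([3, 2])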
...@@ -78,7 +78,7 @@ class Fp8Padding(torch.nn.Module): ...@@ -78,7 +78,7 @@ class Fp8Padding(torch.nn.Module):
number of GEMMs to be performed simultaneously. number of GEMMs to be performed simultaneously.
align_size : int, optional align_size : int, optional
the alignment size for the input tensor. If not provided, the alignment size will the alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first be determined by the FP8/FP4 recipe (32 for MXFP8/NVFP4 and 16 for others) in the first
forward pass. forward pass.
""" """
...@@ -111,7 +111,14 @@ class Fp8Padding(torch.nn.Module): ...@@ -111,7 +111,14 @@ class Fp8Padding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None: if self.align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 self.align_size = (
32
if (
FP8GlobalStateManager.get_fp8_recipe().mxfp8()
or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
)
else 16
)
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
......
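The alignment change above applies to both padding and unpadding: split sizes are rounded up to a multiple of 32 for MXFP8/NVFP4 recipes and 16 otherwise. A hedged sketch of the likely rounding arithmetic behind padded_m_splits (the exact expression is elided in the hunk):

def pad_m_splits(m_splits, align_size):
    """Round each GEMM's row count up to a multiple of align_size (illustrative only)."""
    return [(m + align_size - 1) // align_size * align_size for m in m_splits]

print(pad_m_splits([100, 257, 64], 32))  # [128, 288, 64]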
...@@ -75,9 +75,9 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -75,9 +75,9 @@ class Fp8Unpadding(torch.nn.Module):
num_gemms : int num_gemms : int
number of GEMMs to be performed simultaneously. number of GEMMs to be performed simultaneously.
align_size : int, optional align_size : int, optional
the alignment size for the input tensor. If not provided, the alignment size will The alignment size for the input tensor. If not provided, the alignment size will
be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first be automatically determined based on the FP8/FP4 recipe in the first forward pass:
forward pass. 32 for MXFP8 or NVFP4, otherwise 16.
""" """
def __init__( def __init__(
...@@ -109,7 +109,14 @@ class Fp8Unpadding(torch.nn.Module): ...@@ -109,7 +109,14 @@ class Fp8Unpadding(torch.nn.Module):
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
if self.align_size is None: if self.align_size is None:
self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 self.align_size = (
32
if (
FP8GlobalStateManager.get_fp8_recipe().mxfp8()
or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
)
else 16
)
# FP8 padding calculate # FP8 padding calculate
padded_m_splits = [ padded_m_splits = [
......
...@@ -14,6 +14,7 @@ import transformer_engine_torch as tex ...@@ -14,6 +14,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from .base import ( from .base import (
get_dummy_wgrad,
get_multi_stream_cublas_workspace, get_multi_stream_cublas_workspace,
get_dummy_wgrad, get_dummy_wgrad,
TransformerEngineBaseModule, TransformerEngineBaseModule,
...@@ -42,10 +43,10 @@ from ..cpp_extensions import ( ...@@ -42,10 +43,10 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.quantized_tensor import ( from ..quantized_tensor import (
QuantizedTensorStorage, QuantizedTensorStorage,
Quantizer, Quantizer,
prepare_for_saving, prepare_for_saving,
...@@ -111,9 +112,15 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -111,9 +112,15 @@ class _GroupedLinear(torch.autograd.Function):
is_fp8_activation_recompute_enabled() is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase() and not in_fp8_activation_recompute_phase()
) )
if weight_quantizers[0] is not None: # No need to set the quantizer states if weight is already quantized
if weight_quantizers[0] is not None and not isinstance(
weights[0], QuantizedTensorStorage
):
for weight_quantizer in weight_quantizers: for weight_quantizer in weight_quantizers:
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weights[0], QuantizedTensorStorage):
# If weights are already quantized, no need to set quantizer states
weight_quantizers = [weight._quantizer for weight in weights]
if output_quantizers[0] is not None: if output_quantizers[0] is not None:
for output_quantizer in output_quantizers: for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False) output_quantizer.set_usage(rowwise=True, columnwise=False)
...@@ -132,6 +139,9 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -132,6 +139,9 @@ class _GroupedLinear(torch.autograd.Function):
else: else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
if cpu_offloading:
start_offload(*inputmats)
# Initialize weights # Initialize weights
weights_fp8: list weights_fp8: list
if fp8: if fp8:
...@@ -193,6 +203,9 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -193,6 +203,9 @@ class _GroupedLinear(torch.autograd.Function):
for i in range(num_gemms): for i in range(num_gemms):
weight_quantizers[i].calibrate(weights[i]) weight_quantizers[i].calibrate(weights[i])
if cpu_offloading:
mark_not_offload(*weights_fp8, *weights)
if is_grad_enabled: if is_grad_enabled:
ctx.weight_quantizers = weight_quantizers ctx.weight_quantizers = weight_quantizers
ctx.weights_shape_1 = weights[0].shape[1] ctx.weights_shape_1 = weights[0].shape[1]
...@@ -208,10 +221,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -208,10 +221,6 @@ class _GroupedLinear(torch.autograd.Function):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else: else:
inputmats = [None] * num_gemms inputmats = [None] * num_gemms
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
for i in range(num_gemms): for i in range(num_gemms):
weights[i].offloading_activation = False weights[i].offloading_activation = False
...@@ -322,9 +331,9 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -322,9 +331,9 @@ class _GroupedLinear(torch.autograd.Function):
if ctx.fine_grained_activation_offloading: if ctx.fine_grained_activation_offloading:
origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
if ctx.fuse_wgrad_accumulation: if ctx.fuse_wgrad_accumulation:
for i in range(N): for i in range(N):
origin_weights[i].main_grad = main_grads[i] origin_weights[i].main_grad = main_grads[i]
# Preprocess grad output # Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
...@@ -385,13 +394,11 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -385,13 +394,11 @@ class _GroupedLinear(torch.autograd.Function):
dtype=ctx.activation_dtype, dtype=ctx.activation_dtype,
device=ctx.device, device=ctx.device,
) )
# Make sure weights are available in column-wise format
for weight, quantizer in zip(weights, ctx.weight_quantizers): # for dgrad computation.
if quantizer is not None and isinstance(weight, QuantizedTensorStorage): for weight in weights:
weight.update_usage( if isinstance(weight, QuantizedTensorStorage):
rowwise_usage=quantizer.rowwise_usage, weight.update_usage(columnwise_usage=True)
columnwise_usage=quantizer.columnwise_usage,
)
general_grouped_gemm( general_grouped_gemm(
weights, weights,
grad_output, grad_output,
...@@ -880,7 +887,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -880,7 +887,7 @@ class GroupedLinear(TransformerEngineBaseModule):
Execute the delayed weight gradient computation. Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients. This method is called after the main backward pass to compute weight gradients.
""" """
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): if not self.need_backward_dw():
return return
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop() (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
......
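The backward hunk above forces quantized weights into column-wise usage before the grouped GEMM because dgrad multiplies the incoming gradient by the weight, which the kernel consumes in a transpose-friendly layout; wgrad can be deferred via the wgrad store. An unquantized reference of the grouped backward math, for illustration only:

# Reference (unquantized) math behind the grouped backward, for one GEMM i:
#   dgrad_i = grad_out_i @ W_i          (needs W_i in a column-wise/transpose-friendly layout)
#   wgrad_i = grad_out_i^T @ input_i    (this part can be delayed via the wgrad store)
import torch

def grouped_linear_backward(grad_outputs, inputs, weights):
    dgrads = [g @ w for g, w in zip(grad_outputs, weights)]
    wgrads = [g.t() @ x for g, x in zip(grad_outputs, inputs)]
    return dgrads, wgrads

x = [torch.randn(5, 3), torch.randn(7, 3)]
w = [torch.randn(4, 3), torch.randn(4, 3)]   # TE weights are [out_features, in_features]
g = [torch.randn(5, 4), torch.randn(7, 4)]
dgrads, wgrads = grouped_linear_backward(g, x, w)
print(dgrads[0].shape, wgrads[0].shape)  # torch.Size([5, 3]) torch.Size([4, 3])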