OpenDAS / TransformerEngine / Commits

Commit c1a1c04e, authored Dec 27, 2025 by wenjh

Merge nv_main(2.10) to main

Signed-off-by: wenjh <wenjh@sugon.com>

Parents: e698a0a7, 66aed3ae
Changes: 208

Showing 20 changed files with 1294 additions and 131 deletions (+1294, -131)
transformer_engine/pytorch/cpu_offload_v1.py                        +751  -0
transformer_engine/pytorch/csrc/common.cpp                          +3    -2
transformer_engine/pytorch/csrc/extensions.h                        +4    -3
transformer_engine/pytorch/csrc/extensions/apply_rope.cpp           +12   -5
transformer_engine/pytorch/csrc/extensions/attention.cpp            +32   -25
transformer_engine/pytorch/csrc/extensions/cast.cpp                 +212  -1
transformer_engine/pytorch/csrc/extensions/recipe.cpp               +2    -1
transformer_engine/pytorch/csrc/quantizer.cpp                       +9    -1
transformer_engine/pytorch/csrc/util.cpp                            +33   -26
transformer_engine/pytorch/custom_recipes/__init__.py               +0    -0
transformer_engine/pytorch/custom_recipes/gemm.py                   +6    -6
transformer_engine/pytorch/custom_recipes/quantization.py           +0    -0
transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py     +7    -7
transformer_engine/pytorch/custom_recipes/utils.py                  +0    -0
transformer_engine/pytorch/distributed.py                           +40   -2
transformer_engine/pytorch/graph.py                                 +81   -20
transformer_engine/pytorch/module/base.py                           +57   -8
transformer_engine/pytorch/module/fp8_padding.py                    +9    -2
transformer_engine/pytorch/module/fp8_unpadding.py                  +11   -4
transformer_engine/pytorch/module/grouped_linear.py                 +25   -18

transformer_engine/pytorch/cpu_offload_v1.py  (new file, mode 0 → 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Functionality for CPU offloading of tensors saved for backward pass."""
from __future__ import annotations
from contextlib import nullcontext
from typing import Any, Dict, Optional

import torch

from transformer_engine.debug.pytorch.debug_state import TEDebugState

from .quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor

__all__ = ["get_cpu_offload_context"]

CPUOffloadEnabled = False
CPUOffloadedLayer = False


def get_cpu_offloading():
    global CPUOffloadEnabled
    return CPUOffloadEnabled


def set_cpu_offloading(cpu_offloading):
    global CPUOffloadEnabled
    CPUOffloadEnabled = cpu_offloading


def mark_activation_offload(*tensors):
    """Set the type of the offloading needed for a tensor."""
    if TEDebugState.debug_enabled:
        raise RuntimeError("CPU offload is not supported in debug mode.")
    for tensor in tensors:
        if tensor is None:
            continue
        if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
            tensor.activation_offloading = True
        else:
            data_tensors = tensor.get_data_tensors()
            for tensor in data_tensors:
                if tensor is not None:
                    tensor.activation_offloading = True
                    # This is a hack to force clear the tensor after it is offloaded.
                    # It is needed, because .*TensorStorage classes are saved in the ctx,
                    # and they contain the reference to their data tensors.
                    tensor.needs_force_clear = True


def is_cpu_offload_enabled() -> bool:
    """Check if CPU offloading is currently enabled."""
    return CPUOffloadEnabled


def is_current_layer_offloaded() -> bool:
    """Check if the current layer is being offloaded."""
    return CPUOffloadedLayer


class CpuOffloadSavedTensorHook:
    """Context-manager that executes a pair of pack/unpack hooks for saved tensors.

    In this context, the ``on_save_for_backward`` method will be called every time
    a tensor is saved for backward (this includes intermediary results saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation).

    The ``on_get_saved_tensors`` method will be called when the backward function
    of this op attempts to retrieve the saved tensor from context (this includes
    :func:`torch.Tensor.backward()` or :func:`torch.autograd.grad()`). It takes as
    input the return value of ``on_save_for_backward``, and is meant to return
    an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of
    size, device and element values.

    Example:

        >>> import torch
        >>> from typing import Any
        >>>
        >>> class DummyHook(CpuOffloadSavedTensorHook):
        ...
        ...     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        ...         logging.info("On save", tensor)
        ...         return (tensor,)
        ...
        ...     def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        ...         logging.info("On get", saved_state)
        ...         tensor, = saved_state
        ...         return tensor
        ...
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with DummyHook():
        ...     y = a * b
        ...
        On save tensor([1., 1., 1., 1., 1.], requires_grad=True)
        On save tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),)
        On get (tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>),)

    """

    def __init__(self) -> None:
        self.inside_context = False

    def __enter__(self):
        global CPUOffloadEnabled
        CPUOffloadEnabled = True

        self.inside_context = True
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.on_save_for_backward, self.on_get_saved_tensor
        )

    def __exit__(self, *args: Any):
        global CPUOffloadEnabled
        CPUOffloadEnabled = False

        self.inside_context = False
        torch._C._autograd._pop_saved_tensors_default_hooks()

    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        """On save for backward."""
        raise NotImplementedError(
            "`on_save_for_backward: Callable[[torch.Tensor], Any]` "
            "is not implemented in CpuOffloadHook class. Inherit "
            "this class and implement your custom hooks"
        )

    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        """On get saved tensor."""
        raise NotImplementedError(
            "`on_get_saved_tensors: Callable[[Any], torch.Tensor]` "
            "is not implemented in CpuOffloadHook class. Inherit "
            "this class and implement your custom hooks"
        )


class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
    """Context-manager that offloads/recovers tensors through an offload handler.

    The hook just offloads/recovers the tensor object to the handler through the `tensor_push`
    and `tensor_pop` interface. How the offload-handler manages the offloading, recovering
    or prefetching timing is transparent to this hook.
    """

    def __init__(
        self,
        offload_handler: OffloadHandler,
        handler_extra_kwargs: Optional[Dict[str, Any]] = None,
        debug: bool = False,
    ) -> None:
        if handler_extra_kwargs is None:
            handler_extra_kwargs = {}
        self.debug: bool = debug
        self.offload_handler: OffloadHandler = offload_handler
        self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs
        super().__init__()

    def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
        retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
        return retrieve_identifier

    def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
        tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs)
        return tensor


class OffloadHandler:
    """A base class for CPU offload-handler."""

    def __init__(self) -> None:
        pass

    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        """Tensor push."""
        raise NotImplementedError(
            "`tensor_push` is not implemented in OffloadHandler class. "
            "Inherit this class and implement your custom tensor_push."
        )

    def tensor_pop(self, tensor_tag: Any, **kwargs):
        """Tensor pop."""
        raise NotImplementedError(
            "`tensor_pop` is not implemented in OffloadHandler class. "
            "Inherit this class and implement your custom tensor_pop."
        )


class GroupCommitFunction(torch.autograd.Function):
    """This is a dummy op with output identical to input.

    However, it is necessary for marking a timepoint for the offload handler to
    accomplish all synchronizations. Implementing it as a function is necessary
    because we need actions in both forward and backward.
    """

    @staticmethod
    def forward(ctx, tensor, cpu_offload_handler):
        # pylint: disable=missing-function-docstring
        cpu_offload_handler.on_group_commit_forward()
        ctx.cpu_offload_handler = cpu_offload_handler
        # return the identical tensor
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=missing-function-docstring
        cpu_offload_handler = ctx.cpu_offload_handler
        cpu_offload_handler.on_group_commit_backward()
        return grad_output, None


group_prefetch_offload_commit = GroupCommitFunction.apply


class SynchronizedGroupOffloadHandler(OffloadHandler):
    """Offload Handler that offloads/reloads in a synchronized way.

    The device-to-host and host-to-device copying happen in the same stream
    as the computation kernels, thus the copying will block computation.
    """

    def __init__(
        self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False
    ) -> None:
        super().__init__()
        self.num_offload_group = num_offload_group
        self.tensor_need_offloading_checker = tensor_need_offloading_checker
        self.debug = debug
        self.groupid_reset()

    def groupid_reset(self):
        """Groupid reset."""
        # Data structures to label saved tensors and book-keep their cpu copies.
        # Currently, on push, create a new cpu tensor and copies; on pop, copies
        # the tensor back to gpu and deletes the cpu tensor.
        # These will increment whenever `group_commit()` is invoked
        self.current_group, self.tensor_count_current_group = (0, 0)
        self.torch_tensor_count = 0
        self.tensor_tag_to_state = {}

    def on_group_commit_forward(self):
        """On group commit forward."""
        # finishing up with updating current group and tensor count
        self.current_group += 1  # increment
        self.tensor_count_current_group = 0  # reset

    def on_group_commit_backward(self):
        """On group commit backward."""
        self.current_group -= 1
        assert self.current_group >= 0

    @staticmethod
    def offload(src_tensor, pin_memory=True):
        """Offload."""

        cpu_backup = torch.empty(
            src_tensor.size(),
            dtype=src_tensor.dtype,
            layout=src_tensor.layout,
            device="cpu",
            pin_memory=pin_memory,
        )

        cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
        state = (src_tensor.device, cpu_backup)
        return state

    @staticmethod
    def reload(state, non_blocking=None, copy_buffer=None):
        """Reload."""
        dev, cpu_backup = state
        if non_blocking is None:
            non_blocking = cpu_backup.is_pinned()
        if copy_buffer is None:
            return cpu_backup.to(dev, non_blocking=non_blocking)
        assert (
            cpu_backup.size() == copy_buffer.size()
        ), "Can't copy two buffers of different sizes!"
        copy_buffer.copy_(cpu_backup, non_blocking=non_blocking)
        return copy_buffer

    def tensor_push(self, tensor: torch.Tensor, **kwargs):
        """Tensor push."""
        # obtain a unique tensor tag
        tensor_tag = (self.current_group, self.tensor_count_current_group)
        self.tensor_count_current_group += 1
        assert tensor_tag not in self.tensor_tag_to_state
        if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
            tensor
        ):
            state = SynchronizedGroupOffloadHandler.offload(tensor)
            self.tensor_tag_to_state[tensor_tag] = state
        else:
            # will be offloaded together after group commit
            self.tensor_tag_to_state[tensor_tag] = tensor

        return tensor_tag

    def tensor_pop(self, tensor_tag, **kwargs):
        """Tensor pop."""
        assert tensor_tag in self.tensor_tag_to_state
        state = self.tensor_tag_to_state.pop(tensor_tag)
        if isinstance(state, tuple):
            tensor = SynchronizedGroupOffloadHandler.reload(state)
        else:
            tensor = state
        return tensor


class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
    """Compared to the synchronized handler, this uses more memory because of the buffer but
    achieves better performance due to the overlapping. D2H and H2D copying are
    completely hidden behind computation if the computation time of a layer is longer
    than the host-device communication time. Bulk offloading with delay and bulk reloading
    with prefetch are implemented."""

    def __init__(
        self,
        num_offload_group,  # must be <= actual number of groups (number of commits)
        num_model_group,
        tensor_need_offloading_checker=(lambda t: True),
        double_buffering=False,
        debug=False,
    ) -> None:
        super().__init__(
            num_offload_group=num_offload_group,
            tensor_need_offloading_checker=tensor_need_offloading_checker,
            debug=debug,
        )
        # Number of layers in the model
        self.num_layers = num_model_group
        # Data Structure to maintain reference to activation tensors
        self.tensor_tag_to_buf = {}
        # Data structure to hold the FP8/MXFP8 tensor objects
        self.fp8_tensor_object_map = {}
        self.float8_transpose_cache_valid = {}
        self.dereferencing_list = []
        # Tracking the number of layers offloaded
        self.offloaded_group_count = 0
        # Core data structure that decides the window for offloading
        self.layer_window_map = {}
        # Data structures for double buffered reloading
        self.double_buffering = double_buffering
        self.reload_double_buffer = [[], []]
        self.double_buffer_created = False

        # Logic to make offloading load balance across computation
        # for optimal CPU/GPU interconnect usage
        constant = 0
        for i in range(self.num_offload_group):
            self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1
            if i < (self.num_layers % self.num_offload_group):
                self.layer_window_map[i] += i + 1
                constant = i + 1
            else:
                self.layer_window_map[i] += constant

        # allocate streams and events for synchronization
        self.d2h_stream = torch.cuda.Stream()
        self.h2d_stream = torch.cuda.Stream()

    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        global CPUOffloadedLayer

        torch_stray_tensor = isinstance(
            tensor,
            (
                torch._subclasses.fake_tensor.FakeTensor,
                torch._subclasses.functional_tensor.FunctionalTensor,
            ),
        )
        is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage)

        if not torch_stray_tensor:

            # obtain a unique tensor tag
            tensor_tag = (self.current_group, self.tensor_count_current_group)
            self.tensor_count_current_group += 1
            assert tensor_tag not in self.tensor_tag_to_state

            if is_quantized_tensor:
                tensor_list, _ = tensor.prepare_for_saving()

                self.tensor_tag_to_state[tensor_tag] = []
                self.tensor_tag_to_buf[tensor_tag] = []

                # Added support for de-duplicating FP8 param tensors
                for _, value in self.fp8_tensor_object_map.items():
                    if tensor is value:
                        self.dereferencing_list.append(tensor_tag)
                        break
                self.fp8_tensor_object_map[tensor_tag] = tensor

                if isinstance(tensor, Float8Tensor):
                    self.float8_transpose_cache_valid[tensor_tag] = getattr(
                        tensor, "_transpose_invalid"
                    )
            else:
                tensor_list = [tensor]

            for t in tensor_list:
                if is_quantized_tensor:
                    self.tensor_tag_to_state[tensor_tag].append(t)
                else:
                    self.tensor_tag_to_state[tensor_tag] = t

                if (
                    self.current_group < self.num_offload_group
                    and self.tensor_need_offloading_checker(t)
                ):
                    if is_quantized_tensor:
                        self.tensor_tag_to_buf[tensor_tag].append(t)
                        # Need to clear the internal data reference for the quantized tensors
                        tensor.clear()
                    else:
                        self.tensor_tag_to_buf[tensor_tag] = t
                    # Needed to differentiate non offloaded layer's attention
                    # QKV layout of attention of non-offloaded layer needs
                    # to be modified while reloading
                    CPUOffloadedLayer = True
        else:
            tensor_tag = (-1, self.torch_tensor_count)
            self.torch_tensor_count += 1
            self.tensor_tag_to_state[tensor_tag] = tensor

        return tensor_tag

    def tensor_pop(self, tensor_tag, **kwargs):
        """Tensor pop."""
        global CPUOffloadedLayer
        assert tensor_tag in self.tensor_tag_to_state
        tensor = self.tensor_tag_to_state.pop(tensor_tag)

        # Handling the quantized tensor case specially here
        if isinstance(tensor, list):
            # If it's a duplicated tensor, we don't need to locally
            # write back a tensor as it would already be written
            if tensor_tag in self.dereferencing_list:
                self.dereferencing_list.remove(tensor_tag)
            else:
                self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
            tensor = self.fp8_tensor_object_map.pop(tensor_tag)
            if self.double_buffering:
                tensor._do_not_clear = True

        self.tensor_tag_to_buf.pop(tensor_tag, None)

        # the tensor should have been copied back in on_group_commit_backward()
        # which invokes bulk_reload_group.
        assert not isinstance(tensor, tuple)
        return tensor

    def bulk_offload_group(self, group_to_offload):
        """Bulk offload group."""
        with torch.cuda.stream(self.d2h_stream):
            for tensor_tag, state in self.tensor_tag_to_state.items():
                group_id, _ = tensor_tag
                if group_id == group_to_offload:
                    assert not isinstance(state, tuple)

                    is_quantized_tensor = isinstance(state, list)

                    if is_quantized_tensor:
                        tensor_list = state
                        self.tensor_tag_to_state[tensor_tag] = []
                    else:
                        tensor_list = [state]

                    for tensor_on_device in tensor_list:
                        # `tensor_offloaded` is a hacky way of dealing with columnwise-only
                        # quantized tensors for CPU offloading. The complication is due to
                        # the `rowwise_data` being `None`. The offloading checker incorrectly
                        # returns `False` and the entire `state` ([None, columnwise_tensor])
                        # is added to the tensor tag state dict. A better design would change
                        # how quantized tensors are kept track of in the offload handler.
                        # Currently at every stage it is ensured that a quantized tensor is a
                        # list whereas a non-quantized tensor is a standalone object, which is
                        # not good! TODO(@sanandaraj5597)
                        tensor_offloaded = False
                        # if offload, return the reference to cpu copy
                        if self.tensor_need_offloading_checker(tensor_on_device):
                            tensor_offloaded = True
                            state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
                        if is_quantized_tensor:
                            if tensor_offloaded:
                                self.tensor_tag_to_state[tensor_tag].append(state)
                            else:
                                self.tensor_tag_to_state[tensor_tag].append(tensor_on_device)
                        else:
                            self.tensor_tag_to_state[tensor_tag] = state

    def synchronize_on_group_commit_forward(self, current_group):
        """Synchronize on group commit forward."""
        global CPUOffloadedLayer

        # For the first group, kickstart the offload after we have
        # the first compute completion
        if current_group == 0:
            self.d2h_stream.wait_stream(torch.cuda.current_stream())
            if not self.double_buffer_created:
                # Creating the first copy of double buffer for tensors that are offloaded
                for tensor_tag, buf in self.tensor_tag_to_buf.items():
                    if isinstance(buf, list):
                        for b in buf:
                            self.reload_double_buffer[0].append(
                                torch.empty_like(b) if self.double_buffering else None
                            )
                    else:
                        self.reload_double_buffer[0].append(
                            torch.empty_like(buf) if self.double_buffering else None
                        )
            self.bulk_offload_group(current_group)

        # Window map data structure helps us synchronize based on number
        # of layers offloaded
        if self.layer_window_map[self.offloaded_group_count] == current_group:

            # Stream synchronization both ways
            self.d2h_stream.wait_stream(torch.cuda.current_stream())
            torch.cuda.current_stream().wait_stream(self.d2h_stream)

            # Time to free the activation memory after usage
            for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
                if tensor_tag[0] == self.offloaded_group_count:
                    if hasattr(tensor_buf, "needs_force_clear"):
                        # Need to clear activation tensor - sometimes references persist
                        # in the code. This is the case for example with the
                        # Float8TensorStorage class, which is saved directly inside the ctx
                        # while its internal tensors are saved inside save_for_backward.
                        tensor_buf.data = torch.Tensor()
                    # Release the pointer to the tensor
                    self.tensor_tag_to_buf[tensor_tag] = None

            # Time to offload the next group
            if self.offloaded_group_count < (self.num_offload_group - 1):
                self.bulk_offload_group(self.offloaded_group_count + 1)

            # Increment the offload group count to keep track
            self.offloaded_group_count += 1

        if current_group == (self.num_offload_group - 1):
            CPUOffloadedLayer = False

        if not self.double_buffer_created:
            # Creating second copy of double buffer for tensors that are offloaded
            if current_group == (self.num_layers - 1):
                for buf in self.reload_double_buffer[0]:
                    self.reload_double_buffer[1].append(
                        torch.empty_like(buf) if self.double_buffering else None
                    )
                self.double_buffer_created = True

    def on_group_commit_forward(self):
        """This function will cause host device synchronization"""
        # handle synchronization events
        self.synchronize_on_group_commit_forward(self.current_group)

        super().on_group_commit_forward()

    def bulk_reload_group(self, group_to_reload):
        """Bulk reload group."""
        assert group_to_reload < self.num_offload_group
        buffer_idx = 0
        double_buffer_idx = group_to_reload % 2
        main_stream = torch.cuda.current_stream()

        with torch.cuda.stream(self.h2d_stream):
            # move back tensors
            for tensor_label, state in self.tensor_tag_to_state.items():
                group_id, _ = tensor_label
                if group_id == group_to_reload:
                    if isinstance(state, tuple):
                        if self.double_buffering:
                            reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
                        else:
                            with torch.cuda.stream(main_stream):
                                reload_buffer = torch.empty_like(
                                    state[1], device=torch.cuda.current_device()
                                )
                        recovered_tensor = SynchronizedGroupOffloadHandler.reload(
                            state, True, reload_buffer
                        )
                        buffer_idx = buffer_idx + 1
                        self.tensor_tag_to_state[tensor_label] = recovered_tensor
                    elif isinstance(state, list):
                        tensor_list = []
                        for state_tuple in state:
                            if isinstance(state_tuple, tuple):
                                if self.double_buffering:
                                    reload_buffer = self.reload_double_buffer[double_buffer_idx][
                                        buffer_idx
                                    ]
                                else:
                                    with torch.cuda.stream(main_stream):
                                        reload_buffer = torch.empty_like(
                                            state_tuple[1], device=torch.cuda.current_device()
                                        )
                                tensor_list.append(
                                    SynchronizedGroupOffloadHandler.reload(
                                        state_tuple,
                                        True,
                                        reload_buffer,
                                    )
                                )
                                buffer_idx = buffer_idx + 1
                            else:
                                tensor_list.append(state_tuple)

                        # No need to write back the duplicated tensor again
                        # to the same location, this check ensures that
                        if tensor_label in self.dereferencing_list:
                            self.dereferencing_list.remove(tensor_label)
                        else:
                            _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(
                                tensor_list
                            )

                        if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
                            self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
                                self.float8_transpose_cache_valid.pop(tensor_label)
                            )

                        self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
                            tensor_label
                        )

    def on_group_commit_backward(self):
        # first decrement the current group.
        # after last commit in forward, the group will +1; in backward it -1.
        # Finally it should be decremented to 0.
        self.current_group -= 1
        assert self.current_group >= 0

        # Layer window data structure helps us to reload at right times
        if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group:

            # Stream synchronization both ways
            self.h2d_stream.wait_stream(torch.cuda.current_stream())
            torch.cuda.current_stream().wait_stream(self.h2d_stream)

            # Time to reload the next group
            self.bulk_reload_group(self.offloaded_group_count - 1)

            # Decrease the offloading group counter
            self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0

        # Last group computation needs to wait till all the reloads complete
        if self.current_group == 0:
            torch.cuda.current_stream().wait_stream(self.h2d_stream)
            self.offloaded_group_count = 0


def get_cpu_offload_context(
    enabled: bool = False,
    num_layers: int = 1,
    model_layers: int = 1,
    offload_activations: bool = True,
    offload_weights: bool = False,
    double_buffering: bool = False,
):
    """
    This function returns the CPU Offload context and the synchronizer function that needs to be
    used after every transformer layer. Returns `nullcontext()` if offloading is not enabled.

    Usage:

    .. code-block:: python

        cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True)

        with cpu_offload_context:
            te_layer.forward(inp_tensor)
        cpu_offload_synchronizer()

    Parameters
    ----------
    enabled: bool, default = `False`
             When set to True, CPU Offloading functionality is enabled.
    num_layers: int, default = 1
                Determines the number of transformer layers
                you want to offload activations/weights for.
    model_layers: int, default = 1
                  Number of layers in the model that will be used under this context.
    offload_activations: bool, default = `True`
                         When set to `True`, offloads the activations for the TE layer.
    offload_weights: bool, default = `False`
                     When set to `True`, offloads the weights for the TE layer.
    double_buffering: bool, default = `False`
                      When set to `True`, uses double buffering for offloading.
    """

    if not offload_weights and not offload_activations:
        raise ValueError(
            "CPU Offloading is enabled while it is not "
            "mentioned what to offload (weights/activations)"
        )

    if offload_weights:
        import warnings

        warnings.warn(
            "Offloading weights is deprecated. Using offload_weights=True does not have any"
            " effect.",
            DeprecationWarning,
        )

        # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
        if not offload_activations:
            return nullcontext(), lambda x: x

    def tensor_need_offloading_checker_activations(tensor):
        return hasattr(tensor, "activation_offloading")

    tensor_need_offloading_checker = tensor_need_offloading_checker_activations

    cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
        num_offload_group=num_layers,
        num_model_group=model_layers,
        tensor_need_offloading_checker=tensor_need_offloading_checker,
        double_buffering=double_buffering,
    )

    def group_prefetch_offload_commit_async(tensor):
        return group_prefetch_offload_commit(tensor, cpu_offload_handler)

    if enabled:
        return (
            CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler),
            group_prefetch_offload_commit_async,
        )
    return nullcontext(), group_prefetch_offload_commit_async
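
As a rough end-to-end illustration of the API in this new file, the sketch below drives several layers under the offload context and passes the output through the returned synchronizer after each layer. It is a minimal sketch: the layer module, hidden sizes, and layer counts are made-up illustration values, and only `get_cpu_offload_context` and the context/synchronizer pair come from this file.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import get_cpu_offload_context

# Hypothetical sizes, for illustration only.
num_model_layers = 4      # total transformer layers in the model
num_offload_layers = 2    # how many layers' activations to offload to CPU

layers = torch.nn.ModuleList(
    [te.TransformerLayer(1024, 4096, 16) for _ in range(num_model_layers)]
).cuda()

offload_ctx, sync_fn = get_cpu_offload_context(
    enabled=True,
    num_layers=num_offload_layers,
    model_layers=num_model_layers,
    offload_activations=True,
    double_buffering=True,
)

inp = torch.randn(128, 8, 1024, device="cuda", requires_grad=True)
hidden = inp
for layer in layers:
    with offload_ctx:
        hidden = layer(hidden)
    # The synchronizer is the dummy autograd op that marks a group commit, letting the
    # handler schedule bulk offloads in forward and bulk prefetches in backward.
    hidden = sync_fn(hidden)

hidden.sum().backward()

With `num_layers=2` and `model_layers=4`, the handler's `layer_window_map` spreads the two offloaded groups across the four commits, so the D2H copies on the side stream can overlap later layers' compute.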

transformer_engine/pytorch/csrc/common.cpp

@@ -190,8 +190,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
     const std::vector<size_t> meta_shape{1};
     ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
     ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
-    auto scale_inv_dtype =
-        (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
+    auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING)   ? DType::kFloat8E8M0
+                           : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
+                                                                     : DType::kFloat32;
     ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
     ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype,
                                  columnwise_scale_inv_shape);

transformer_engine/pytorch/csrc/extensions.h

@@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph);

 std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
                                                       const std::vector<size_t> &shape, DType dtype,

@@ -94,7 +94,7 @@ std::vector<py::object> fused_attn_fwd(
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
     const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
-    size_t rng_elts_per_thread);
+    size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph);

 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,

@@ -106,7 +106,7 @@ std::vector<py::object> fused_attn_bwd(
     const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
-    py::handle dp_quantizer, py::handle dqkv_quantizer);
+    py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph);

 at::Tensor fa_prepare_fwd(at::Tensor qkvi);
 at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

@@ -384,6 +384,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
                               const int cp_rank);

 at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+                               const std::optional<at::Tensor> start_positions,
                                const NVTE_QKV_Format qkv_format, const bool interleaved,
                                const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                                const int cp_rank);

transformer_engine/pytorch/csrc/extensions/apply_rope.cpp

@@ -163,6 +163,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
 }

 at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+                               const std::optional<at::Tensor> start_positions,
                                const NVTE_QKV_Format qkv_format, const bool interleaved,
                                const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                                const int cp_rank) {

@@ -180,6 +181,12 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
   auto freqs_cu = makeTransformerEngineTensor(freqs);
   auto input_grads_cu = makeTransformerEngineTensor(input_grads);

+  auto start_positions_cu = TensorWrapper();  // empty start_positions tensor
+  if (start_positions) {
+    start_positions_cu = makeTransformerEngineTensor(start_positions.value());
+    TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor");
+  }
+
   if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
     TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
     TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");

@@ -208,8 +215,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
     auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());

     nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                             input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
-                             max_s, b, h, d, d2, stride_t,
+                             start_positions_cu.data(), input_grads_cu.data(), qkv_format,
+                             interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
                              /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream());

     return input_grads;

@@ -246,9 +253,9 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
   auto cu_seqlens_cu = TensorWrapper();  // empty cu_seqlens tensor

   nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                           input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
-                           h, d, d2, stride_s, stride_b, stride_h, stride_d,
-                           at::cuda::getCurrentCUDAStream());
+                           start_positions_cu.data(), input_grads_cu.data(), qkv_format,
+                           interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b,
+                           stride_h, stride_d, at::cuda::getCurrentCUDAStream());

   return input_grads;
 }

transformer_engine/pytorch/csrc/extensions/attention.cpp

@@ -45,14 +45,15 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
 #ifdef __HIP_PLATFORM_AMD__
   return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
 #else
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
       bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
-      max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
+      max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
+      return_max_logit, cuda_graph);
   return fused_attention_backend;
 #endif
 }

@@ -110,7 +111,7 @@ std::vector<py::object> fused_attn_fwd(
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
     const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
-    size_t rng_elts_per_thread) {
+    size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) {
 #ifdef __HIP_PLATFORM_AMD__
   assert(false);
 #else

@@ -235,8 +236,9 @@ std::vector<py::object> fused_attn_fwd(
           te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
           te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
           te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
-          attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-          window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
+          return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type,
+          attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(),
+          at::cuda::getCurrentCUDAStream());
     });

     // allocate memory for workspace and auxiliary output tensors

@@ -256,7 +258,9 @@ std::vector<py::object> fused_attn_fwd(
     };

     // allocate memory for nvte_aux_tensor_pack.tensors
     // f16_max512 : S [b, h, sq, skv]
-    // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
+    // f16_arbitrary:
+    //   return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
+    //   return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
     // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
     size_t i = 0;
     at::Tensor output_tensor;

@@ -265,8 +269,8 @@ std::vector<py::object> fused_attn_fwd(
         allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
                       static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
     set_tensor_param(i++, output_tensor);
-    // fp8 has an additional softmax stats tensor, ZInv
-    if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
+    // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor
+    if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
       output_tensor =
           allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
                         static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);

@@ -292,8 +296,9 @@ std::vector<py::object> fused_attn_fwd(
           te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
           te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
           te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
-          attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-          window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
+          return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type,
+          attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(),
+          at::cuda::getCurrentCUDAStream());
     });

     // destroy tensor wrappers, but not allocated memory

@@ -315,7 +320,7 @@ std::vector<py::object> fused_attn_bwd(
     const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
-    py::handle dp_quantizer, py::handle dqkv_quantizer) {
+    py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) {
 #ifdef __HIP_PLATFORM_AMD__
   assert(false);
 #else

@@ -533,13 +538,14 @@ std::vector<py::object> fused_attn_bwd(
     // populate tensors with appropriate shapes and dtypes
     NVTE_SCOPED_GIL_RELEASE({
       nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                           te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                           te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
                           te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                           te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                           max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout,
                           bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1],
-                          deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
+                          deterministic, cuda_graph, workspace.data(),
+                          at::cuda::getCurrentCUDAStream());
     });

     // allocate memory for workspace

@@ -549,13 +555,14 @@ std::vector<py::object> fused_attn_bwd(
     // execute kernel
     NVTE_SCOPED_GIL_RELEASE({
       nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                           te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                           te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
                           te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
                           te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
                           max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout,
                           bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1],
-                          deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
+                          deterministic, cuda_graph, workspace.data(),
+                          at::cuda::getCurrentCUDAStream());
     });

     // destroy tensor wrappers

transformer_engine/pytorch/csrc/extensions/cast.cpp

@@ -491,6 +491,207 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
   return retval;
 }

+// allocate fp4 data, fp8 scalings, and amax values
+// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN]
+// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate
+std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_nvfp4_tensors(
+    std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
+    std::vector<NVFP4Quantizer *> &quantizer_cpp_list) {
+  init_extension();
+  std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
+  auto &tensor_py_list = std::get<0>(retval);
+  auto &tensor_cpp_list = std::get<1>(retval);
+
+  // Number of tensors
+  const size_t num_tensors = shape_list.size();
+  if (num_tensors == 0) {
+    return retval;
+  }
+
+  // Quantization parameters
+  const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
+  const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
+  const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
+  const auto fp4_dtype = quantizer_cpp_list[0]->dtype;
+  constexpr size_t scale_elem_size = 1;
+
+  // Helper function to construct tensor view
+  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
+  // will survive until all views are deleted.
+  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
+                            size_t offset, at::ScalarType dtype) -> at::Tensor {
+    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
+    bool is_empty_shape = product(shape) == 0;
+    if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
+      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
+    }
+    return at::from_blob(
+        buffer->data_ptr<uint8_t>() + offset, shape_int64,
+        [buffer](void *) {},  // deleter holds shared_ptr
+        at::device(at::kCUDA).dtype(dtype));
+  };
+
+  // Lambda function for converting std::vector<size_t> shape to NVFP4 shape (last dim divided by 2)
+  auto to_fp4_shape = [](const std::vector<size_t> &shape) {
+    std::vector<size_t> fp4_shape(shape.begin(), shape.end());
+    if (!fp4_shape.empty()) {
+      fp4_shape.back() /= 2;
+    }
+    return fp4_shape;
+  };
+
+  // Allocate row-wise data
+  std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list, amax_rowwise_list;
+  std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
+  if (rowwise_usage) {
+    // Tensor sizes
+    for (size_t i = 0; i < num_tensors; ++i) {
+      rowwise_data_shapes.emplace_back(shape_list[i]);
+      rowwise_scale_shapes.emplace_back(
+          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
+    }
+
+    // Offsets in full buffer
+    size_t buffer_size = 0;
+    std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 256);  // align to 256B
+      data_offsets.push_back(buffer_size);
+      // Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes).
+      // Integer arithmetic: ceil(product / 2) == (product + 1) / 2.
+      buffer_size += (product(rowwise_data_shapes[i]) + 1) / 2;
+    }
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 16);  // align to 16B
+      scale_offsets.push_back(buffer_size);
+      buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
+    }
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 16);  // align to 16B
+      amax_offsets.push_back(buffer_size);
+      // amax is scalar in fp32, 4 bytes each
+      buffer_size += 4;
+    }
+
+    // Allocate full buffer
+    auto buffer = std::make_shared<at::Tensor>(
+        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+
+    // Construct tensor views
+    for (size_t i = 0; i < num_tensors; ++i) {
+      rowwise_data_list.emplace_back(make_torch_view(
+          buffer, to_fp4_shape(rowwise_data_shapes[i]), data_offsets[i], torch::kUInt8));
+      rowwise_scale_list.emplace_back(
+          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
+      amax_rowwise_list.emplace_back(
+          make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kUInt8));
+    }
+  }
+
+  // Allocate column-wise data
+  std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list, amax_columnwise_list;
+  std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
+  if (columnwise_usage) {
+    // Tensor sizes
+    for (size_t i = 0; i < num_tensors; ++i) {
+      // push the transposed shape into NVFP4 columnwise shape
+      // NVFP4 on SM100 is TN only
+      columnwise_data_shapes.emplace_back();
+      auto &shape = columnwise_data_shapes.back();
+      shape.push_back(shape_list[i].back());
+      for (size_t j = 0; j < shape_list[i].size() - 1; ++j) {
+        shape.push_back(shape_list[i][j]);
+      }
+      columnwise_scale_shapes.emplace_back(
+          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
+    }
+
+    // Offsets in full buffer
+    size_t buffer_size = 0;
+    std::vector<size_t> data_offsets, scale_offsets, amax_offsets;
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 256);  // align to 256B
+      data_offsets.push_back(buffer_size);
+      // Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes).
+      // Integer arithmetic: ceil(product / 2) == (product + 1) / 2.
+      buffer_size += (product(columnwise_data_shapes[i]) + 1) / 2;
+    }
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 16);  // align to 16B
+      scale_offsets.push_back(buffer_size);
+      buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
+    }
+    for (size_t i = 0; i < num_tensors; ++i) {
+      buffer_size = roundup(buffer_size, 16);  // align to 16B
+      amax_offsets.push_back(buffer_size);
+      // amax is scalar in fp32, 4 bytes each
+      buffer_size += 4;
+    }
+
+    // Allocate full buffer
+    auto buffer = std::make_shared<at::Tensor>(
+        at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+
+    // Construct tensor views
+    for (size_t i = 0; i < num_tensors; ++i) {
+      columnwise_data_list.emplace_back(make_torch_view(
+          buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8));
+      columnwise_scale_list.emplace_back(
+          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
+      amax_columnwise_list.emplace_back(
+          make_torch_view(buffer, std::vector<size_t>{1}, amax_offsets[i], torch::kUInt8));
+    }
+  }
+
+  // Construct nvfp4 tensors
+  py::handle NVFP4TensorClass(reinterpret_cast<PyObject *>(NVFP4TensorStoragePythonClass));
+  for (size_t i = 0; i < num_tensors; ++i) {
+    // Create tensor objects with proper reference counting
+    py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
+    py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
+    py::object columnwise_data = (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
+    py::object columnwise_scale = (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
+    py::object amax_rowwise = rowwise_usage ? py::cast(amax_rowwise_list[i]) : py::none();
+    py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none();
+
+    // Construct Python tensor
+    tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data,
+                                                 columnwise_scale, amax_rowwise, amax_columnwise,
+                                                 fp4_dtype, quantizer_py_list[i]));
+
+    // Construct C++ tensor
+    // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor,
+    // then set the amax and amax_columnwise values.
+    {
+      auto tensor_wrapper = makeTransformerEngineTensor(
+          rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
+          columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
+          rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
+          columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp4_dtype,
+          /*amax_ptr=*/nullptr, /*scale_ptr=*/nullptr,
+          rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
+          columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
+          rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
+          columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode);
+      // Set the amax rowwise and amax columnwise if available
+      if (rowwise_usage) {
+        tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32,
+                                std::vector<size_t>{1});
+      }
+      if (columnwise_usage) {
+        tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32,
+                                           std::vector<size_t>{1});
+      }
+      tensor_cpp_list.emplace_back(std::move(tensor_wrapper));
+    }
+  }
+
+  return retval;
+}
+
 }  // namespace

 std::vector<py::object> split_quantize(const at::Tensor &tensor,

@@ -549,7 +750,8 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
   bool use_fused_bulk_alloc = true;
   for (size_t i = 0; i < quantizer_list.size(); i++) {
     if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) &&
-        !detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) {
+        !detail::IsMXFP8Quantizers(quantizer_list[i].ptr()) &&
+        !detail::IsNVFP4Quantizers(quantizer_list[i].ptr())) {
       use_fused_bulk_alloc = false;
       break;
     }

@@ -570,6 +772,7 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
     // TODO(zhongbo): make a better api to make this part less hacky
     bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr());
     bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr());
+    bool is_nvfp4 = detail::IsNVFP4Quantizers(quantizer_list[0].ptr());
     if (is_fp8_blockwise) {
       // FP8 block-scaling: construct output tensors with bulk allocations
       std::vector<Float8BlockQuantizer *> blockwise_quantizers;

@@ -586,6 +789,14 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
       }
       std::tie(output_py_list, output_cpp_list) =
          bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
+    } else if (is_nvfp4) {
+      // NVFP4: construct output tensors with bulk allocations
+      std::vector<NVFP4Quantizer *> nvfp4_quantizers;
+      for (auto &quantizer : quantizer_cpp_list) {
+        nvfp4_quantizers.push_back(static_cast<NVFP4Quantizer *>(quantizer.get()));
+      }
+      std::tie(output_py_list, output_cpp_list) =
+          bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers);
     } else {
       NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer");
     }
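
The buffer layout arithmetic in the new `bulk_allocate_nvfp4_tensors` can be restated in a few lines: packed FP4 data blocks are 256-byte aligned and take ceil(numel/2) bytes (two 4-bit values per byte), while the 1-byte FP8 scale blocks and the single 4-byte FP32 amax per tensor are each 16-byte aligned. The Python sketch below mirrors that offset computation for two hypothetical split shapes; the helper names and the scale shapes are illustrative assumptions, not part of the extension API.

from math import prod

def roundup(x: int, align: int) -> int:
    # Round x up to the next multiple of `align` (mirrors the C++ roundup helper).
    return ((x + align - 1) // align) * align

def nvfp4_rowwise_layout(shapes, scale_shapes):
    """Compute byte offsets for packed FP4 data, FP8 scales, and FP32 amax values.

    Hypothetical restatement of the layout logic in bulk_allocate_nvfp4_tensors.
    """
    size = 0
    data_offsets, scale_offsets, amax_offsets = [], [], []
    for shape in shapes:
        size = roundup(size, 256)               # data blocks are 256B-aligned
        data_offsets.append(size)
        size += (prod(shape) + 1) // 2          # ceil(numel / 2) bytes of FP4 data
    for sshape in scale_shapes:
        size = roundup(size, 16)                # scale blocks are 16B-aligned
        scale_offsets.append(size)
        size += prod(sshape)                    # 1 byte per FP8 scale element
    for _ in shapes:
        size = roundup(size, 16)                # amax blocks are 16B-aligned
        amax_offsets.append(size)
        size += 4                               # one FP32 amax per tensor
    return size, data_offsets, scale_offsets, amax_offsets

# Two hypothetical split shapes with made-up scale shapes (illustrative numbers only).
total, d, s, a = nvfp4_rowwise_layout([(128, 256), (64, 256)], [(8, 16), (4, 16)])
print(total, d, s, a)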

transformer_engine/pytorch/csrc/extensions/recipe.cpp

@@ -20,10 +20,11 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
   TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
   TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");

+  auto *amax_ptr = amax.data_ptr<float>();
   TensorWrapper fake_te_output(
       nullptr, te_input.shape(),
       DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
-      amax.data_ptr<float>());
+      amax_ptr);

   nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
 }

transformer_engine/pytorch/csrc/quantizer.cpp

@@ -1200,6 +1200,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
                                              rowwise_scale_inv_shape.end());
     rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts);
     rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
+    // hadamard amax kernel will zero out pointer with ZeroAmaxKernel
+    // nvte_compute_amax_with_config will zero out the pointer if needed
     amax_rowwise = at::empty({1}, bit32_tensor_opts);
   }
   if (columnwise_usage) {

@@ -1213,6 +1215,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
     columnwise_data_tensor = at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts);
     columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts);
+    // hadamard amax kernel will zero out pointer with ZeroAmaxKernel
+    // nvte_compute_amax_with_config will zero out the pointer if needed
     amax_columnwise = at::empty({1}, bit32_tensor_opts);
   }

@@ -1352,6 +1356,8 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
   }
   if (!amax_rowwise) {
     const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    // hadamard amax kernel will zero out pointer with ZeroAmaxKernel
+    // nvte_compute_amax_with_config will zero out the pointer if needed
     amax_rowwise = at::empty({1}, opts);
     tensor.attr("_amax_rowwise") = *amax_rowwise;
   }

@@ -1392,7 +1398,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
     }
     if (!amax_columnwise) {
       const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-      amax_columnwise = at::zeros({1}, opts);
+      // hadamard amax kernel will zero out pointer with ZeroAmaxKernel
+      // nvte_compute_amax_with_config will zero out the pointer if needed
+      amax_columnwise = at::empty({1}, opts);
       tensor.attr("_amax_columnwise") = *amax_columnwise;
     }
   } else {  // columnwise_usage == false
View file @
c1a1c04e
...
...
@@ -50,8 +50,6 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
void
*
scale_inv_dptr
=
scale_inv
.
data_ptr
;
void
*
swizzled_scale_inv_dptr
=
getDataPtr
(
swizzled_scale_inv
,
0
);
// Reconstruct input only to avoid swizzling both directions if not needed.
// The specific dtype used is irrelevant, just needs to be correct bits.
transformer_engine
::
TensorWrapper
input_cu
(
input
.
scaling_mode
());
transformer_engine
::
TensorWrapper
output_cu
(
input
.
scaling_mode
());
...
...
@@ -100,10 +98,14 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
if
(
tensors
.
front
().
scaling_mode
()
==
NVTE_INVALID_SCALING
)
{
NVTE_ERROR
(
"Invalid scaling mode for swizzle."
);
}
else
if
(
tensors
.
front
().
scaling_mode
()
!=
NVTE_MXFP8_1D_SCALING
)
{
}
else
if
(
tensors
.
front
().
scaling_mode
()
!=
NVTE_MXFP8_1D_SCALING
&&
tensors
.
front
().
scaling_mode
()
!=
NVTE_NVFP4_1D_SCALING
)
{
return
std
::
nullopt
;
}
const
auto
scaling_mode
=
tensors
.
front
().
scaling_mode
();
const
auto
nvfp4
=
scaling_mode
==
NVTE_NVFP4_1D_SCALING
;
std
::
vector
<
transformer_engine
::
TensorWrapper
>
wrappers
;
std
::
vector
<
NVTETensor
>
input_tensors
,
output_tensors
;
...
...
@@ -131,39 +133,44 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
// Allocate full buffer
auto
buffer
=
at
::
empty
({(
int64_t
)
buffer_size
},
at
::
device
(
at
::
kCUDA
).
dtype
(
torch
::
kUInt8
));
const
auto
input_dtype
=
(
nvfp4
)
?
transformer_engine
::
DType
::
kFloat4E2M1
:
transformer_engine
::
DType
::
kFloat8E4M3
;
const
auto
scale_inv_dtype
=
(
nvfp4
)
?
transformer_engine
::
DType
::
kFloat8E4M3
:
transformer_engine
::
DType
::
kFloat8E8M0
;
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
auto
&
tensor
=
tensors
[
i
];
void
*
scale_inv_dptr
=
scale_inv_dptrs
[
i
];
void
*
swizzled_scale_inv_dptr
=
getDataPtr
(
buffer
,
scale_inv_offsets
[
i
]);
auto
input_shape
=
nvte_shape_to_vector
(
tensor
.
shape
());
// auto input_shape = nvte_shape_to_vector(tensor.shape());
NVTEShape
nvte_input_shape
;
if
(
rowwise
)
{
nvte_input_shape
=
tensor
.
shape
();
}
else
{
nvte_input_shape
=
tensor
.
get_columnwise_data
().
shape
;
}
auto
input_shape
=
nvte_shape_to_vector
(
nvte_input_shape
);
// Reconstruct input only to avoid swizzling both directions if not needed.
// Use any 8 bit type, it's irrelevant.
transformer_engine
::
TensorWrapper
input_cu
(
NVTE_MXFP8_1D_SCALING
);
transformer_engine
::
TensorWrapper
output_cu
(
NVTE_MXFP8_1D_SCALING
);
transformer_engine
::
TensorWrapper
input_cu
(
scaling_mode
);
transformer_engine
::
TensorWrapper
output_cu
(
scaling_mode
);
if
(
rowwise
)
{
input_cu
.
set_rowwise_data
(
tensor
.
dptr
(),
transformer_engine
::
DType
::
kFloat8E4M3
,
input_shape
);
input_cu
.
set_rowwise_scale_inv
(
scale_inv_dptr
,
transformer_engine
::
DType
::
kFloat8E8M0
,
scale_inv_shapes
[
i
]);
output_cu
.
set_rowwise_data
(
tensor
.
dptr
(),
transformer_engine
::
DType
::
kFloat8E4M3
,
input_shape
);
output_cu
.
set_rowwise_scale_inv
(
swizzled_scale_inv_dptr
,
transformer_engine
::
DType
::
kFloat8E8M0
,
scale_inv_shapes
[
i
]);
input_cu
.
set_rowwise_data
(
tensor
.
dptr
(),
input_dtype
,
input_shape
);
input_cu
.
set_rowwise_scale_inv
(
scale_inv_dptr
,
scale_inv_dtype
,
scale_inv_shapes
[
i
]);
output_cu
.
set_rowwise_data
(
tensor
.
dptr
(),
input_dtype
,
input_shape
);
output_cu
.
set_rowwise_scale_inv
(
swizzled_scale_inv_dptr
,
scale_inv_dtype
,
scale_inv_shapes
[
i
]);
// Set the swizzled scaling factor to the original tensor.
tensor
.
set_rowwise_scale_inv
(
swizzled_scale_inv_dptr
,
transformer_engine
::
DType
::
kFloat8E8M0
,
scale_inv_shapes
[
i
]);
tensor
.
set_rowwise_scale_inv
(
swizzled_scale_inv_dptr
,
scale_inv_dtype
,
scale_inv_shapes
[
i
]);
}
else
{
input_cu
.
set_columnwise_data
(
tensor
.
columnwise_dptr
(),
transformer_engine
::
DType
::
kFloat8E4M3
,
input_shape
);
input_cu
.
set_columnwise_scale_inv
(
scale_inv_dptr
,
transformer_engine
::
DType
::
kFloat8E8M0
,
scale_inv_shapes
[
i
]);
output_cu
.
set_columnwise_data
(
tensor
.
columnwise_dptr
(),
transformer_engine
::
DType
::
kFloat8E4M3
,
input_shape
);
output_cu
.
set_columnwise_scale_inv
(
swizzled_scale_inv_dptr
,
transformer_engine
::
DType
::
kFloat8E8M0
,
scale_inv_shapes
[
i
]);
input_cu
.
set_columnwise_data
(
tensor
.
columnwise_dptr
(),
input_dtype
,
input_shape
);
input_cu
.
set_columnwise_scale_inv
(
scale_inv_dptr
,
scale_inv_dtype
,
scale_inv_shapes
[
i
]);
output_cu
.
set_columnwise_data
(
tensor
.
columnwise_dptr
(),
input_dtype
,
input_shape
);
output_cu
.
set_columnwise_scale_inv
(
swizzled_scale_inv_dptr
,
scale_inv_dtype
,
scale_inv_shapes
[
i
]);
// Set the swizzled scaling factor to the original tensor.
tensor
.
set_columnwise_scale_inv
(
swizzled_scale_inv_dptr
,
transformer_engine
::
DType
::
kFloat8E8M0
,
scale_inv_shapes
[
i
]);
tensor
.
set_columnwise_scale_inv
(
swizzled_scale_inv_dptr
,
scale_inv_dtype
,
scale_inv_shapes
[
i
]);
}
input_tensors
.
emplace_back
(
input_cu
.
data
());
...
...
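For orientation, the dtype selection added in this hunk can be summarized as follows. This is an illustrative Python-side table derived from the diff above, not code from the repository:

    # Data / scale-inverse dtype pairs chosen by the swizzle path per scaling mode.
    SWIZZLE_DTYPES = {
        "NVTE_MXFP8_1D_SCALING": {"data": "kFloat8E4M3", "scale_inv": "kFloat8E8M0"},
        "NVTE_NVFP4_1D_SCALING": {"data": "kFloat4E2M1", "scale_inv": "kFloat8E4M3"},
    }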
transformer_engine/pytorch/experimental/__init__.py → transformer_engine/pytorch/custom_recipes/__init__.py (file moved)

transformer_engine/pytorch/experimental/gemm.py → transformer_engine/pytorch/custom_recipes/gemm.py
...
@@ -2,21 +2,21 @@
#
# See LICENSE for license information.
-"""GEMM API for experimental middleware between Transformer Engine and Kitchen."""
+"""GEMM API that enables custom GEMM logic for custom quantization recipes."""
from typing import Iterable, Optional
import torch
-from transformer_engine.pytorch.experimental.quantization import (
+from transformer_engine.pytorch.custom_recipes.quantization import (
    MMParams,
    GEMMType,
)
-from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
-from transformer_engine.pytorch.tensor.utils import is_experimental
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer
+from transformer_engine.pytorch.tensor.utils import is_custom
-def experimental_gemm(
+def custom_gemm(
    A: QuantizedTensorStorage,
    B: QuantizedTensorStorage,
    workspace: torch.Tensor,  # pylint: disable=unused-argument
...
@@ -32,7 +32,7 @@ def experimental_gemm(
    grad: bool = False,
) -> Iterable[Optional[torch.Tensor]]:
    """Dispatch GEMM to quantizer's qgemm method."""
-    assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors"
+    assert is_custom(A) and is_custom(B), "A and B must be custom tensors"
    A, B = B, A
...

transformer_engine/pytorch/experimental/quantization.py → transformer_engine/pytorch/custom_recipes/quantization.py (file moved)

transformer_engine/pytorch/experimental/quantization_nvfp4.py → transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
...
@@ -9,9 +9,9 @@ from typing import Optional, Tuple, Union
import torch
-from transformer_engine.pytorch.experimental import quantization
-from transformer_engine.pytorch.experimental import utils
-from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
+from transformer_engine.pytorch.custom_recipes import quantization
+from transformer_engine.pytorch.custom_recipes import utils
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer
def nvfp4_ref_rht_2d_quantizer_factory(role):
...
@@ -229,8 +229,8 @@ class NVFP4TensorRef(QuantizedTensorStorage):
    _quantizer: Optional[Quantizer] = None
    @property
-    def experimental(self) -> bool:
-        """Flag to indicate this quantizer is using experimental Kitchen middleware."""
+    def custom(self) -> bool:
+        """Flag to indicate this quantized tensor is custom."""
        return True
    def prepare_for_saving(
...
@@ -362,8 +362,8 @@ class NVFP4QuantizerRef(Quantizer):
        self.with_random_sign_mask = with_random_sign_mask
    @property
-    def experimental(self) -> bool:
-        """Flag to indicate this quantizer is using experimental Kitchen middleware"""
+    def custom(self) -> bool:
+        """Flag to indicate this quantizer is custom."""
        return True
    @staticmethod
...

transformer_engine/pytorch/experimental/utils.py → transformer_engine/pytorch/custom_recipes/utils.py (file moved)
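A hedged sketch of the renamed extension point: under the new naming, a custom recipe's tensor and quantizer classes advertise themselves through a `custom` property, and `custom_gemm()` checks `is_custom()` on both operands before dispatching to the quantizer's `qgemm` hook. The class below is purely illustrative and is not part of the repository:

    # Illustrative only: a storage class that opts into the custom-recipe GEMM path.
    from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage

    class MyCustomTensor(QuantizedTensorStorage):
        @property
        def custom(self) -> bool:
            # is_custom() is expected to read this flag when custom_gemm() validates A and B.
            return True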
transformer_engine/pytorch/distributed.py
...
@@ -29,24 +29,25 @@ except ImportError:
import transformer_engine_torch as tex
-from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from . import torch_version
from .utils import (
    is_non_tn_fp8_gemm_supported,
    safely_set_viewless_tensor_data,
    needs_quantized_gemm,
)
from .constants import dist_group_type
from .quantization import FP8GlobalStateManager, autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.nvfp4_tensor import NVFP4Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
+from .quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
+from .triton.pad import pad_columnwise_scale_inv
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
...
@@ -1889,6 +1890,43 @@ def allreduce(
    return inp, handle
+def _get_module_fsdp_state(module):
+    """
+    If module is an FSDP module, return its _FSDPState.
+    Otherwise, return the _FSDPState of the closest parent FSDP module
+    in the module hierarchy the module belongs to.
+    """
+    if hasattr(module, "_get_fsdp_state"):
+        # this will return correct fsdp state if module itself is an fsdp module
+        fsdp_state = module._get_fsdp_state()
+    elif getattr(module, "_te_cached_parent_fsdp_state", None) is not None:
+        # See if we have cached the parent fsdp state of the module
+        fsdp_state = module._te_cached_parent_fsdp_state
+    else:
+        from torch.distributed._composable_state import _module_state_mapping
+        # Otherwise get the fsdp state of lca of module in the module hierarchy
+        min_nodes_in_parent = float("inf")
+        closest_parent_fsdp_mod = None
+        for fsdp_mod in _module_state_mapping.keys():
+            all_submodules = list(fsdp_mod.modules())
+            for submodule in all_submodules:
+                if submodule is module:
+                    if min_nodes_in_parent > len(all_submodules):
+                        closest_parent_fsdp_mod = fsdp_mod
+                        min_nodes_in_parent = len(all_submodules)
+        if closest_parent_fsdp_mod is None:
+            raise RuntimeError(
+                "Module is not FSDP-wrapped and does not have any FSDP-wrapped parent modules."
+            )
+        fsdp_state = closest_parent_fsdp_mod._get_fsdp_state()
+        # Cache the parent fsdp state of the module to avoid recomputing
+        # the closest parent fsdp module.
+        module._te_cached_parent_fsdp_state = fsdp_state
+    return fsdp_state
def _fsdp_scatter_tensors(
    fsdp_group: dist_group_type,
    *tensors: torch.Tensor,
...
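The lookup above is essentially a "smallest wrapping ancestor" search over the FSDP-wrapped modules, cached on the submodule after the first hit. A self-contained toy sketch of the same idea, using an ordinary dict in place of torch's internal `_module_state_mapping`:

    import torch.nn as nn

    root = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), nn.Linear(4, 2))
    target = root[0][0]  # submodule whose owning state we want
    wrapped_states = {root: "root-state", root[0]: "inner-state"}  # stand-ins for FSDP states

    best, best_size = None, float("inf")
    for mod, state in wrapped_states.items():
        submods = list(mod.modules())
        if target in submods and len(submods) < best_size:
            best, best_size = state, len(submods)

    print(best)  # "inner-state" -- the smallest wrapping ancestor wins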
transformer_engine/pytorch/graph.py
...
@@ -322,14 +322,16 @@ def _make_graphed_callables(
    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
+    bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
    graph_callables = [None for _ in range(len(flatten_sample_args))]
    # For cases with multiple active RNG states, e.g. TP.
    if graph_safe_rng_available():
        for _, state in get_all_rng_states().items():
-            for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
+            for fwd_graph, bwd_graph, bwd_dw_graph in zip(fwd_graphs, bwd_graphs, bwd_dw_graphs):
                fwd_graph.register_generator_state(state)
                bwd_graph.register_generator_state(state)
+                bwd_dw_graph.register_generator_state(state)
    mempool = graph_pool_handle() if pool is None else pool
...
@@ -366,21 +368,8 @@ def _make_graphed_callables(
    ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."
    # Filter the TE modules that cudagraph can access.
-    visited_te_modules = set()
-    def hook_fn(module, inputs, outputs):
-        # pylint: disable=unused-argument
-        if isinstance(module, TransformerEngineBaseModule):
-            visited_te_modules.add(module)
-        # If forward is called on a BasicOperation directly the hook will run
-        elif isinstance(module, BasicOperation):
-            visited_te_modules.add(module)
-        # If forward is called on a te.ops.Sequential it is not called on its constituent ops
-        elif isinstance(module, Sequential):
-            assert module._module_groups is not None, "Should have been initialized by warmup"
-            for module_group in module._module_groups:
-                if isinstance(module_group, OperationFuser):
-                    for basic_op in module_group._basic_ops:
-                        visited_te_modules.add(basic_op)
+    visited_te_modules = {}
+    need_bwd_dw_graph = {}
    # Run warmup and do the above filtering.
    with torch.cuda.stream(torch.cuda.Stream()):
...
@@ -388,6 +377,31 @@ def _make_graphed_callables(
            args = sample_args[func_idx]
            kwargs = sample_kwargs[func_idx]
            static_input_surface = per_callable_static_input_surfaces[func_idx]
+            def hook_fn(module, inputs, outputs, func_idx=func_idx):
+                # pylint: disable=unused-argument
+                modules = set()
+                if isinstance(module, TransformerEngineBaseModule):
+                    modules.add(module)
+                # If forward is called on a BasicOperation directly the hook will run
+                elif isinstance(module, BasicOperation):
+                    modules.add(module)
+                # If forward is called on a te.ops.Sequential it is not called on its constituent ops
+                elif isinstance(module, Sequential):
+                    assert (
+                        module._module_groups is not None
+                    ), "Should have been initialized by warmup"
+                    for module_group in module._module_groups:
+                        if isinstance(module_group, OperationFuser):
+                            for basic_op in module_group._basic_ops:
+                                modules.add(basic_op)
+                if modules:
+                    if func_idx not in visited_te_modules:
+                        visited_te_modules[func_idx] = modules
+                    else:
+                        visited_te_modules[func_idx].update(modules)
            for warmup_iter in range(num_warmup_iters):
                hooks = []
                for module in func.modules():
...
@@ -432,6 +446,15 @@ def _make_graphed_callables(
                        module_params_with_grad
                    )
                    per_callable_static_input_surfaces[func_idx] = static_input_surface
+                    # Run wgrad. This is essential for some TE modules when they have
+                    # delay_wgrad_compute enabled.
+                    need_backward_dw = False
+                    for module in visited_te_modules.get(func_idx, set()):
+                        if hasattr(module, "need_backward_dw") and module.need_backward_dw():
+                            need_backward_dw = True
+                            module.backward_dw()
+                    need_bwd_dw_graph[func_idx] = need_backward_dw
                else:
                    grad_inputs = None
                del outputs, grad_inputs
...
@@ -514,6 +537,17 @@ def _make_graphed_callables(
                    allow_unused=allow_unused_input,
                    retain_graph=retain_graph_in_backward,
                )
+            # If no one module needs the backward_dw, the bwd_dw_graph will be empty.
+            # So skip capturing it.
+            if need_bwd_dw_graph[per_callable_bwd_idx]:
+                bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
+                with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                    for module in visited_te_modules[per_callable_bwd_idx]:
+                        if (
+                            hasattr(module, "need_backward_dw") and module.need_backward_dw()
+                        ):
+                            module.backward_dw()
            # Constructs a tuple suitable for returning from Graphed.backward:
            # Pads out the actually-needed grads with Nones in gradient slots for inputs
            # that don't require grad. I couldn't think of a one-liner for this pattern.
...
@@ -582,10 +616,12 @@ def _make_graphed_callables(
        # Capture backward graphs in reverse order
        per_callable_static_grad_outputs = []
        per_callable_static_grad_inputs = []
-        for static_input_surface, static_outputs, bwd_graph in zip(
+        for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip(
            reversed(per_callable_static_input_surfaces),
            reversed(per_callable_static_outputs),
            reversed(bwd_graphs),
+            reversed(bwd_dw_graphs),
+            reversed(range(len(per_callable_static_input_surfaces))),
        ):
            # For now, assumes all static_outputs require grad
            static_grad_outputs = tuple(
...
@@ -601,6 +637,11 @@ def _make_graphed_callables(
                    allow_unused=allow_unused_input,
                    retain_graph=retain_graph_in_backward,
                )
+            if need_bwd_dw_graph[bwd_idx]:
+                with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                    for module in visited_te_modules[bwd_idx]:
+                        if hasattr(module, "need_backward_dw") and module.need_backward_dw():
+                            module.backward_dw()
            # Constructs a tuple suitable for returning from Graphed.backward:
            # Pads out the actually-needed grads with Nones in gradient slots for inputs that
            # don't require grad. I couldn't think of a slick one-liner for this pattern.
...
@@ -715,6 +756,21 @@ def _make_graphed_callables(
        return functionalized
+    def make_graphed_attribute_functions(graph_idx):
+        # Attach backward_dw as an attribute to the graphed callable.
+        def backward_dw():
+            if need_bwd_dw_graph.get(graph_idx, False):
+                bwd_dw_graphs[graph_idx].replay()
+        # Attach reset as an attribute to the graphed callable.
+        def reset():
+            fwd_graphs[graph_idx].reset()
+            bwd_graphs[graph_idx].reset()
+            bwd_dw_graphs[graph_idx].reset()
+        return backward_dw, reset
    # Put together the final graphed callables
    ret = []
    for i in range(len(sample_args)):
...
@@ -732,9 +788,10 @@ def _make_graphed_callables(
        )
        func = graph_callables[i]
+        te_modules = visited_te_modules.get(i, set())
        if isinstance(func, torch.nn.Module):
-            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
+            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd, te_modules):
                def new_fwd(*user_args, **user_kwargs):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
...
@@ -743,7 +800,7 @@ def _make_graphed_callables(
                        if FP8GlobalStateManager.is_fp8_enabled():
                            fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
                            for m in func.modules():
-                                if m not in visited_te_modules:
+                                if m not in te_modules:
                                    # Only Set the FP8 meta for the modules included by forward
                                    continue
                                if isinstance(m, TransformerEngineBaseModule):
...
@@ -780,7 +837,7 @@ def _make_graphed_callables(
                return new_fwd
-            forward = make_graphed_forward(func, func.training, graphed, func.forward)
+            forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules)
            if _order is None:
                func.forward = forward
                ret.append(func)
...
@@ -789,6 +846,10 @@ def _make_graphed_callables(
        else:
            ret.append(graphed)
+        backward_dw_func, reset_func = make_graphed_attribute_functions(i)
+        setattr(ret[-1], "backward_dw", backward_dw_func)
+        setattr(ret[-1], "reset", reset_func)
        if just_one_callable:
            return ret[0]
...
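A hedged usage sketch of the new attributes attached to each graphed callable: after `make_graphed_callables`, the returned callable carries `backward_dw()` (replays the captured weight-gradient graph, a no-op when nothing was delayed) and `reset()` (releases the forward, backward, and wgrad graphs). The module and keyword argument below are assumptions for illustration; they require a CUDA device and a TE build that accepts `delay_wgrad_compute`:

    import torch
    import transformer_engine.pytorch as te

    model = te.Linear(1024, 1024, delay_wgrad_compute=True).cuda()
    sample = torch.randn(8, 1024, device="cuda", requires_grad=True)

    graphed = te.make_graphed_callables(model, (sample,))

    out = graphed(sample)
    out.sum().backward()   # replays the captured dgrad graph
    graphed.backward_dw()  # replays the captured wgrad graph (skipped if it was never captured)
    graphed.reset()        # releases all three captured graphs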
transformer_engine/pytorch/module/base.py
...
@@ -17,6 +17,7 @@ from types import MethodType
import torch
import torch.nn.functional as F
+from torch.distributed.tensor import DTensor
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
...
@@ -38,7 +39,7 @@ from ..distributed import (
    _fsdp_gather_tensors,
)
from ..constants import dist_group_type
-from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
+from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
...
@@ -707,6 +708,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None
        self.wgrad_accumulation_and_reduce_hooks = []
+        self.wgrad_store = None
        if not TEDebugState.debug_enabled:
            TEDebugState.initialize()
...
@@ -1288,7 +1290,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        metedata used in deferred initialization.
        """
        super().register_parameter(name, param)
-        self.param_init_meta[name] = _ParameterInitMeta(**kwargs)
+        # Initialize param_init_meta exactly once during the init. FSDP2 can call
+        # register parameter again to change parameters to DTensors. And it calls
+        # it without custom fp8 specific kwargs that we need. And so we dont want
+        # to reset/loose our fp8 init attributes.
+        if hasattr(self, "param_init_meta") and name not in self.param_init_meta:
+            self.param_init_meta[name] = _ParameterInitMeta(**kwargs)
    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
...
@@ -1300,10 +1307,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            return
        for name, param in self.named_parameters(recurse=False):
+            # Check if parameter is a DTensor (FSDP2) or regular tensor
+            is_dtensor = isinstance(param, DTensor)
+            dtensor_param = param if is_dtensor else None
+            # Need to update/quantize local tensor in case of DTensor
+            param = param._local_tensor if is_dtensor else param
            # Ensure parameter is on a real device
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="cuda")
            # Initialize the parameter values on device
            init_fn = self.param_init_meta[name].init_fn
            get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
...
@@ -1332,7 +1343,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                    raise RuntimeError("Weight quantizer has not been initialized")
                quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
                quantizer.internal = False
+                if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
+                    device_mesh = dtensor_param.device_mesh
+                    amax_reduction_group = (
+                        device_mesh.get_group(mesh_dim="shard")
+                        if device_mesh.ndim > 1
+                        else device_mesh.get_group()
+                    )
+                    quantizer.amax_reduction_group = amax_reduction_group
+                    quantizer.with_amax_reduction = True
                # Quantize parameter
                param = quantizer(param)
...
@@ -1340,7 +1359,18 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            # re-applying the nn.Parameter() wrap is a no-op when the input is already
            # a parameter so we always re-apply it just for extra safety.
-            param = torch.nn.Parameter(param)
+            if is_dtensor:
+                # recreate the DTensor from the parameter.
+                dtensor_param = DTensor.from_local(
+                    param,
+                    device_mesh=dtensor_param.device_mesh,
+                    placements=dtensor_param.placements,
+                    shape=dtensor_param.size(),
+                    stride=dtensor_param.stride(),
+                )
+                dtensor_param = torch.nn.Parameter(dtensor_param)
+            else:
+                param = torch.nn.Parameter(param)
            # Keep high-precision values on CPU if needed
            if high_precision_init_val is not None:
...
@@ -1368,8 +1398,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)
-            setattr(self, name, param)
+            # Update the parameter based on its type
+            if not is_dtensor:
+                setattr(self, name, param)
+            else:
+                setattr(self, name, dtensor_param)
    @abstractmethod
    def forward(self):
...
@@ -1526,12 +1560,21 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        """
        self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook)
+    def need_backward_dw(self):
+        """
+        Check if this module needs to execute the delayed weight gradient computation.
+        This method should be used at the beginning of self.backward_dw() to determine if it
+        should actually be executed or just return without doing anything.
+        User can also manually call this method to check that before calling into backward_dw().
+        """
+        return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute()
    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.
        This method is called after the main backward pass to compute weight gradients.
        """
-        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+        if not self.need_backward_dw():
            return
        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
            (wgrad, bgrad), _ = self.wgrad_store.pop()
...
@@ -1568,7 +1611,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                debug = False
            else:
                debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
-        self.debug_last_iteration = TEDebugState.get_iteration()
+            self.debug_last_iteration = TEDebugState.get_iteration()
+            self.debug_enabled_in_this_iteration = debug
+        else:
+            # If this is the same iteration as previous invocation of the module,
+            # we use the debug value from the first invocation in the iteration.
+            debug = self.debug_enabled_in_this_iteration
        return debug
    def no_debug_features_active(self, quantizers):
...
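In eager mode the same pair of methods gives the caller explicit control over when the delayed weight gradient runs. A minimal sketch, assuming a CUDA device and a TE build whose module constructors accept `delay_wgrad_compute`:

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(512, 512, delay_wgrad_compute=True).cuda()
    x = torch.randn(4, 512, device="cuda", requires_grad=True)

    y = layer(x)
    y.sum().backward()       # dgrad only; the wgrad inputs are parked in layer.wgrad_store
    if layer.need_backward_dw():
        layer.backward_dw()  # pops the stored tensors and computes the weight gradient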
transformer_engine/pytorch/module/fp8_padding.py
...
@@ -78,7 +78,7 @@ class Fp8Padding(torch.nn.Module):
        number of GEMMs to be performed simultaneously.
    align_size : int, optional
        the alignment size for the input tensor. If not provided, the alignment size will
-        be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
+        be determined by the FP8/FP4 recipe (32 for MXFP8/NVFP4 and 16 for others) in the first
        forward pass.
    """
...
@@ -111,7 +111,14 @@ class Fp8Padding(torch.nn.Module):
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
        if self.align_size is None:
-            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
+            self.align_size = (
+                32
+                if (
+                    FP8GlobalStateManager.get_fp8_recipe().mxfp8()
+                    or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
+                )
+                else 16
+            )
        # FP8 padding calculate
        padded_m_splits = [
...
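The alignment rule introduced above rounds each GEMM's row count up to a multiple of 32 for MXFP8/NVFP4 recipes and 16 otherwise. A small self-contained illustration of that arithmetic (the helper below is an illustration, not the module's actual implementation):

    def padded_splits(m_splits, mxfp8_or_nvfp4):
        align = 32 if mxfp8_or_nvfp4 else 16
        return [(m + align - 1) // align * align for m in m_splits]

    print(padded_splits([100, 64, 3], mxfp8_or_nvfp4=True))   # [128, 64, 32]
    print(padded_splits([100, 64, 3], mxfp8_or_nvfp4=False))  # [112, 64, 16]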
transformer_engine/pytorch/module/fp8_unpadding.py
...
@@ -75,9 +75,9 @@ class Fp8Unpadding(torch.nn.Module):
    num_gemms : int
        number of GEMMs to be performed simultaneously.
    align_size : int, optional
-        the alignment size for the input tensor. If not provided, the alignment size will
-        be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first
-        forward pass.
+        The alignment size for the input tensor. If not provided, the alignment size will
+        be automatically determined based on the FP8/FP4 recipe in the first forward pass:
+        32 for MXFP8 or NVFP4, otherwise 16.
    """
    def __init__(
...
@@ -109,7 +109,14 @@ class Fp8Unpadding(torch.nn.Module):
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
        if self.align_size is None:
-            self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16
+            self.align_size = (
+                32
+                if (
+                    FP8GlobalStateManager.get_fp8_recipe().mxfp8()
+                    or FP8GlobalStateManager.get_fp8_recipe().nvfp4()
+                )
+                else 16
+            )
        # FP8 padding calculate
        padded_m_splits = [
...
transformer_engine/pytorch/module/grouped_linear.py
...
@@ -14,6 +14,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from .base import (
+    get_dummy_wgrad,
    get_multi_stream_cublas_workspace,
-    get_dummy_wgrad,
    TransformerEngineBaseModule,
...
@@ -42,10 +43,10 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
-from ..cpu_offload import is_cpu_offload_enabled
+from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
-from ..tensor.quantized_tensor import (
+from ..quantized_tensor import (
    QuantizedTensorStorage,
    Quantizer,
    prepare_for_saving,
...
@@ -111,9 +112,15 @@ class _GroupedLinear(torch.autograd.Function):
            is_fp8_activation_recompute_enabled()
            and not in_fp8_activation_recompute_phase()
        )
-        if weight_quantizers[0] is not None:
+        # No need to set the quantizer states if weight is already quantized
+        if weight_quantizers[0] is not None and not isinstance(weights[0], QuantizedTensorStorage):
            for weight_quantizer in weight_quantizers:
                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+        elif isinstance(weights[0], QuantizedTensorStorage):
+            # If weights are already quantized, no need to set quantizer states
+            weight_quantizers = [weight._quantizer for weight in weights]
        if output_quantizers[0] is not None:
            for output_quantizer in output_quantizers:
                output_quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -132,6 +139,9 @@ class _GroupedLinear(torch.autograd.Function):
        else:
            inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
+        if cpu_offloading:
+            start_offload(*inputmats)
        # Initialize weights
        weights_fp8: list
        if fp8:
...
@@ -193,6 +203,9 @@ class _GroupedLinear(torch.autograd.Function):
            for i in range(num_gemms):
                weight_quantizers[i].calibrate(weights[i])
+        if cpu_offloading:
+            mark_not_offload(*weights_fp8, *weights)
        if is_grad_enabled:
            ctx.weight_quantizers = weight_quantizers
            ctx.weights_shape_1 = weights[0].shape[1]
...
@@ -208,10 +221,6 @@ class _GroupedLinear(torch.autograd.Function):
                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
            else:
                inputmats = [None] * num_gemms
-            if inp.requires_grad:
-                for weight in weights_fp8:
-                    if isinstance(weight, QuantizedTensorStorage):
-                        weight.update_usage(columnwise_usage=True)
            for i in range(num_gemms):
                weights[i].offloading_activation = False
...
@@ -322,9 +331,9 @@ class _GroupedLinear(torch.autograd.Function):
                if ctx.fine_grained_activation_offloading:
                    origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
-            if ctx.fuse_wgrad_accumulation:
-                for i in range(N):
-                    origin_weights[i].main_grad = main_grads[i]
+            if ctx.fuse_wgrad_accumulation:
+                for i in range(N):
+                    origin_weights[i].main_grad = main_grads[i]
        # Preprocess grad output
        grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
...
@@ -385,13 +394,11 @@ class _GroupedLinear(torch.autograd.Function):
                    dtype=ctx.activation_dtype,
                    device=ctx.device,
                )
-            for weight, quantizer in zip(weights, ctx.weight_quantizers):
-                if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
-                    weight.update_usage(
-                        rowwise_usage=quantizer.rowwise_usage,
-                        columnwise_usage=quantizer.columnwise_usage,
-                    )
+            # Make sure weights are available in column-wise format
+            # for dgrad computation.
+            for weight in weights:
+                if isinstance(weight, QuantizedTensorStorage):
+                    weight.update_usage(columnwise_usage=True)
            general_grouped_gemm(
                weights,
                grad_output,
...
@@ -880,7 +887,7 @@ class GroupedLinear(TransformerEngineBaseModule):
        Execute the delayed weight gradient computation.
        This method is called after the main backward pass to compute weight gradients.
        """
-        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+        if not self.need_backward_dw():
            return
        with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
            (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
...
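For orientation, a hedged usage sketch of `GroupedLinear` itself: it applies `num_gemms` independent GEMMs to row-wise splits of the input, with the split sizes passed as `m_splits` (which the Fp8Padding/Fp8Unpadding modules above can align for FP8/FP4 recipes). The shapes and constructor arguments below are illustrative assumptions and require a CUDA device:

    import torch
    import transformer_engine.pytorch as te

    glinear = te.GroupedLinear(num_gemms=3, in_features=256, out_features=256).cuda()
    x = torch.randn(96, 256, device="cuda")
    out = glinear(x, m_splits=[32, 32, 32])  # one GEMM per split; splits must sum to x.shape[0]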