Unverified commit 64891899 authored by buptzyb, committed by GitHub

Optimize CUDA Graph memory, FP8 wrapper, and uneven PP support (#1234)



* Reuse cudagraph input and output tensor memory
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* Wrap _make_graphed_callables with fp8
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* Add uneven PP support
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove no-grad tensor reuse
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* Simplify TensorWrapper
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* Format and add comments
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* Revert FP8 wrapper
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply comment tweaks from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Solve lint
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* Remove unused params
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update comment
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------

Signed-off-by: Robin Zhang <robinz@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1dd8f62d
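
For orientation before the diff: a minimal sketch of how the two new options might be passed to `make_graphed_callables`. The module sizes, tensor shapes, and `_order` schedule below are illustrative only, not taken from this commit.

```python
# Hypothetical usage sketch -- sizes, shapes, and schedule are illustrative.
import torch
import transformer_engine.pytorch as te

num_microbatches = 2
layers_per_chunk = [3, 1]  # uneven pipeline chunks: 3 layers, then 1 layer
total_layers = sum(layers_per_chunk)

modules = tuple(te.Linear(1024, 1024) for _ in range(total_layers))

# One args tuple per (layer, microbatch) forward pass; a list so the
# buffer-reuse path may alias entries in place.
sample_args = [
    (torch.randn(8, 1024, device="cuda", requires_grad=True),)
    for _ in range(total_layers * num_microbatches)
]

# Interleaved schedule: 1-based chunk ids, positive = forward,
# negative = backward (see the _order comment in the diff below).
order = [1, 2, 1, 2, -2, -1, -2, -1]

graphed = te.make_graphed_callables(
    modules,
    sample_args,
    _order=order,
    _num_layers_per_chunk=layers_per_chunk,  # new in this commit
    _reuse_graph_input_output_buffers=True,  # new in this commit
)
```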
@@ -21,6 +21,7 @@ from .fp8 import (
 from .distributed import get_all_rng_states, graph_safe_rng_available
 from .module.base import TransformerEngineBaseModule
 from .ops.op import BasicOperation
+from .utils import make_weak_ref

 __all__ = ["make_graphed_callables"]
@@ -63,8 +64,10 @@ def _make_graphed_callables(
     fp8_weight_caching: bool = False,
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
     _order: Optional[List[int]] = None,
+    _num_layers_per_chunk: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
     retain_graph_in_backward: bool = False,
+    _reuse_graph_input_output_buffers: bool = False,
 ) -> SingleOrTuple[Callable]:
     """
     Helper method for `make_graphed_callables`
@@ -110,29 +113,113 @@
         # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py.
         # Note: The model is assumed to consist of layers
         # (corresponding to callables) that are grouped into
-        # equally-sized model chunks. _order is a list of chunk
-        # indices (1-indexed) that indicates the order in which the
-        # layers are evaluated. Positive values indicate forward
-        # passes and negative values indicate backward passes. Each
+        # model chunks. _num_layers_per_chunk is a list of integers
+        # that indicates the number of layers in each model chunk.
+        # _order is a list of chunk indices (1-indexed) that
+        # indicates the order in which the layers are evaluated.
+        # Positive values indicate forward passes and negative
+        # values indicate backward passes. Each
         # entry in sample_args corresponds to one of the forward
         # passes.
         num_model_chunks = max(_order)
         num_microbatches = len(_order) // num_model_chunks // 2
         assert num_model_chunks * num_microbatches * 2 == len(_order)
-        assert len(sample_args) * 2 >= len(_order) and (
-            len(sample_args) * 2 % len(_order) == 0
-        ), f"{len(sample_args)} >= {len(_order)} and {len(sample_args)} % {len(_order)} == 0"
-        num_layers = len(sample_args) // num_model_chunks // num_microbatches
-        assert len(callables) == num_model_chunks * num_layers, (
-            f"Callables should have ({num_model_chunks * num_layers}) "
+
+        # Determine number of layers in each model chunk.
+        if _num_layers_per_chunk is None:
+            assert len(sample_args) * 2 >= len(_order) and (
+                len(sample_args) * 2 % len(_order) == 0
+            ), (
+                f"{len(sample_args)} * 2 >= {len(_order)} and {len(sample_args)} * 2 %"
+                f" {len(_order)} == 0"
+            )
+            num_layers = len(sample_args) // num_model_chunks // num_microbatches
+            _num_layers_per_chunk = [num_layers] * num_model_chunks
+        else:
+            assert (
+                isinstance(_num_layers_per_chunk, int)
+                or len(_num_layers_per_chunk) == num_model_chunks
+            ), (
+                "If _num_layers_per_chunk is provided, it must be an integer or a list of"
+                f" {num_model_chunks} integers, but got {_num_layers_per_chunk}."
+            )
+            if isinstance(_num_layers_per_chunk, int):
+                _num_layers_per_chunk = [_num_layers_per_chunk] * num_model_chunks
+        total_num_layers = sum(_num_layers_per_chunk)
+        assert len(callables) == total_num_layers, (
+            f"Callables should have ({total_num_layers}) "
             + f"entries when order input is provided but got {len(callables)}."
         )
-        assert len(sample_args) == num_model_chunks * num_microbatches * num_layers, (
-            f"Expected {num_model_chunks * num_microbatches}"
+        assert len(sample_args) == total_num_layers * num_microbatches, (
+            f"Expected {total_num_layers * num_microbatches}"
             + f"args tuple, but got {len(sample_args)}."
         )
+        # Calculate the starting index of each chunk in callables for future use.
+        _prefix_num_layers = [0]
+        for m_chunk in range(num_model_chunks):
+            num_layers = _num_layers_per_chunk[m_chunk]
+            _prefix_num_layers.append(_prefix_num_layers[-1] + num_layers)
         assert len(sample_kwargs) == len(sample_args)
+
+        # Check reuse graph conditions and reorganize sample_args and sample_kwargs.
+        # Note: When capturing a graph, we hold onto the args and kwargs so we have static buffers
+        # when the graph is replayed. If two model chunk microbatches have no overlap between their
+        # forward and backward, then we can reduce memory usage by reusing the same static buffers.
+        if _reuse_graph_input_output_buffers:
+            assert (
+                _order is not None
+            ), "`_order` must be provided when `_reuse_graph_input_output_buffers` is True."
+            assert (
+                is_training
+            ), "`_reuse_graph_input_output_buffers` is only available in training mode."
+            assert isinstance(
+                sample_args, list
+            ), "sample_args must be a list for _reuse_graph_input_output_buffers."
+            len_args = len(sample_args[0])
+            for i, arg in enumerate(sample_args):
+                assert len_args == len(
+                    arg
+                ), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`."
+            len_kwargs = len(sample_kwargs[0])
+            assert isinstance(
+                sample_kwargs, list
+            ), "sample_kwargs must be a list for _reuse_graph_input_output_buffers."
+            for i, kwarg in enumerate(sample_kwargs):
+                assert len_kwargs == len(kwarg), (
+                    "Keyword arguments must have same length and shape for"
+                    " `_reuse_graph_input_output_buffers`."
+                )
+
+            # Reorganize args and kwargs for input tensor reuse.
+            fwd_sample_qs = {}
+            consumed_sample_q = []
+            fwd_idx = [0] * num_model_chunks
+            for c_id in _order:
+                m_chunk = abs(c_id) - 1
+                if c_id > 0:
+                    sample_start_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
+                        fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk]
+                    )
+                    fwd_sample_idx = [
+                        sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk])
+                    ]
+                    fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx
+                    for per_callable_fwd_idx in fwd_sample_idx:
+                        if consumed_sample_q:
+                            reuse_fwd_idx = consumed_sample_q.pop(0)
+                            sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx]
+                            sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx]
+                    fwd_idx[m_chunk] += 1
+                else:
+                    num_consumed_samples = min(
+                        len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk]
+                    )
+                    consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples]
+                    fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:]
+
         if fp8_weight_caching:
             # Initialize flag that controls FP8 weight updates
             FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
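
To make the reuse bookkeeping in the hunk above easier to follow, here is a standalone replay of just the index arithmetic (no tensors or graphs). Variable names mirror the diff; the schedule and chunk sizes are made up for illustration.

```python
# Standalone replay of the reuse-index arithmetic (illustrative schedule).
num_model_chunks, num_microbatches = 2, 2
num_layers_per_chunk = [2, 1]
order = [1, 2, -1, 1, -2, 2, -1, -2]  # early backwards enable buffer reuse

prefix = [0]
for n in num_layers_per_chunk:
    prefix.append(prefix[-1] + n)

fwd_sample_qs = {}      # chunk -> forward slots whose static buffers are live
consumed_sample_q = []  # slots freed by completed backward passes
fwd_idx = [0] * num_model_chunks
for c_id in order:
    m_chunk = abs(c_id) - 1
    if c_id > 0:
        # Forward pass: try to reuse buffers freed by earlier backwards.
        start = prefix[m_chunk] * num_microbatches + fwd_idx[m_chunk] * num_layers_per_chunk[m_chunk]
        slots = list(range(start, start + num_layers_per_chunk[m_chunk]))
        fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + slots
        for s in slots:
            if consumed_sample_q:
                print(f"fwd slot {s}: reuses buffers of slot {consumed_sample_q.pop(0)}")
            else:
                print(f"fwd slot {s}: fresh buffers")
        fwd_idx[m_chunk] += 1
    else:
        # Backward pass: this chunk's oldest forward buffers become reusable.
        n = min(len(fwd_sample_qs[m_chunk]), num_layers_per_chunk[m_chunk])
        consumed_sample_q += fwd_sample_qs[m_chunk][:n]
        fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][n:]
```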
@@ -185,10 +272,13 @@
         per_callable_module_params = []
         for m_chunk in range(num_model_chunks):
             for _ in range(num_microbatches):
-                for l_no in range(num_layers):
+                for l_no in range(_num_layers_per_chunk[m_chunk]):
                     per_callable_module_params.append(
-                        tuple(callables[m_chunk * num_layers + l_no].parameters())
-                        if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
+                        tuple(callables[_prefix_num_layers[m_chunk] + l_no].parameters())
+                        if isinstance(
+                            callables[_prefix_num_layers[m_chunk] + l_no],
+                            torch.nn.Module,
+                        )
                         else ()
                     )
         assert len(per_callable_module_params) == len(flatten_sample_args)
@@ -227,10 +317,10 @@
             for c_id in _order:
                 if c_id > 0:
                     m_chunk = c_id - 1
-                    for l_no in range(num_layers):
-                        func = callables[m_chunk * num_layers + l_no]
-                        func_idx = (m_chunk * num_microbatches * num_layers) + (
-                            fwd_idx[m_chunk] * num_layers + l_no
+                    for l_no in range(_num_layers_per_chunk[m_chunk]):
+                        func = callables[_prefix_num_layers[m_chunk] + l_no]
+                        func_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
+                            fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
                         )
                         warmup_func_idx.append(func_idx)
                         warmup_func.append(func)
@@ -255,7 +345,7 @@
             args = sample_args[func_idx]
             kwargs = sample_kwargs[func_idx]
             static_input_surface = per_callable_static_input_surfaces[func_idx]
-            for _ in range(num_warmup_iters):
+            for warmup_iter in range(num_warmup_iters):
                 hooks = []
                 for module in func.modules():
                     hook = module.register_forward_hook(hook_fn)
@@ -271,6 +361,34 @@
                         only_inputs=True,
                         allow_unused=allow_unused_input,
                     )
+                    # Filter module params that get None grad from grad_inputs and remove them
+                    # from static_input_surface. This is to ensure that the backward hooks
+                    # registered to these params are not wrongly triggered.
+                    num_required_grad_sample_args = sum(
+                        arg.requires_grad for arg in flatten_sample_args[func_idx]
+                    )
+                    required_grad_input_idx = []
+                    for i, arg in enumerate(static_input_surface):
+                        if arg.requires_grad:
+                            required_grad_input_idx.append(i)
+                    module_params_with_grad = []
+                    for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
+                        if (
+                            grad_inputs[grad_inputs_idx] is not None
+                            and grad_inputs_idx >= num_required_grad_sample_args
+                        ):
+                            module_params_with_grad.append(static_input_surface[inputs_idx])
+                    if len(module_params_with_grad) != len(per_callable_module_params[func_idx]):
+                        assert warmup_iter == 0, (
+                            "no-grad params should only be used as inputs in the first warmup"
+                            " iteration"
+                        )
+                        per_callable_module_params[func_idx] = tuple(module_params_with_grad)
+                        static_input_surface = flatten_sample_args[func_idx] + tuple(
+                            module_params_with_grad
+                        )
+                        per_callable_static_input_surfaces[func_idx] = static_input_surface
                 else:
                     grad_inputs = None
                 del outputs, grad_inputs
@@ -292,14 +410,16 @@
         per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
         fwd_idx = [0] * num_model_chunks
         bwd_idx = [0] * num_model_chunks
+        static_grad_outputs = None
+        previous_per_callable_bwd_idx = None
         for c_id in _order:
             if c_id > 0:
                 # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
                 m_chunk = c_id - 1
-                for l_no in range(num_layers):
-                    func = callables[m_chunk * num_layers + l_no]
-                    per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
-                        fwd_idx[m_chunk] * num_layers + l_no
+                for l_no in range(_num_layers_per_chunk[m_chunk]):
+                    func = callables[_prefix_num_layers[m_chunk] + l_no]
+                    per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
+                        fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
                     )
                     args = sample_args[per_callable_fwd_idx]
                     kwargs = sample_kwargs[per_callable_fwd_idx]
@@ -314,17 +434,20 @@
             else:
                 # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
                 m_chunk = -c_id - 1
-                for l_no in list(reversed(range(num_layers))):
-                    per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) + (
-                        bwd_idx[m_chunk] * num_layers + l_no
+                for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
+                    per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
+                        bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
                     )
                     static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
                     static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
                     bwd_graph = bwd_graphs[per_callable_bwd_idx]
                     # For now, assumes all static_outputs require grad
-                    static_grad_outputs = tuple(
-                        torch.empty_like(o) if o.requires_grad else None for o in static_outputs
-                    )
+                    if not _reuse_graph_input_output_buffers or static_grad_outputs is None:
+                        # Note for _reuse_graph_input_output_buffers: grad output is only used
+                        # within backward, so we can reuse the same static buffers every time.
+                        static_grad_outputs = tuple(
+                            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                        )
                     if is_training:
                         with torch.cuda.graph(bwd_graph, pool=mempool):
                             grad_inputs = torch.autograd.grad(
@@ -350,6 +473,30 @@
                     per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
                     per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
+
+                    # Weak ref the static outputs and static grad inputs that are no longer needed
+                    # in the following steps. These two types of tensors are both in the cudagraph
+                    # mempool, so we just deallocate them and let PyTorch's memory allocator
+                    # reuse them elsewhere.
+                    if _reuse_graph_input_output_buffers:
+                        # Weak ref the static outputs of the forward pass of this backward. It's
+                        # no longer needed after the corresponding backward graph is built up.
+                        per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref(
+                            static_outputs
+                        )
+                        # Weak ref the static grad inputs of the previous backward pass.
+                        # Note: After a backward pass, we assume Mcore will send the
+                        # grad input to another pipeline parallel rank and that the
+                        # communication is finished before the end of the next backward
+                        # pass.
+                        if previous_per_callable_bwd_idx is not None:
+                            per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = (
+                                make_weak_ref(
+                                    per_callable_static_grad_inputs[previous_per_callable_bwd_idx]
+                                )
+                            )
+                        previous_per_callable_bwd_idx = per_callable_bwd_idx
+
                 bwd_idx[m_chunk] += 1
         else:
             # Capture forward graphs
@@ -634,8 +781,10 @@ def make_graphed_callables(
     fp8_group: Optional[dist_group_type] = None,
     fp8_weight_caching: bool = False,
     _order: Optional[List[int]] = None,
+    _num_layers_per_chunk: Optional[List[int]] = None,
     pool: Optional[Tuple[int, ...]] = None,
     retain_graph_in_backward: bool = False,
+    _reuse_graph_input_output_buffers: bool = False,
 ) -> Union[Callable, Tuple[Callable, ...]]:
     """
     Make CUDA graph version of Transformer Engine modules
@@ -664,6 +813,11 @@
         this graph may share memory with the indicated pool.
     retain_graph_in_backward: bool, default = `False`
         Whether to set retain_graph=True in backward graph capture.
+    _reuse_graph_input_output_buffers: bool, default = `False`
+        Reduce memory usage by reusing input/output data buffers between
+        graphs. Only supported with Mcore interleaved pipeline parallelism, i.e.
+        when `_order` is provided. All callables in `modules` are assumed to have
+        inputs and outputs with the same dtype and shape.

     FP8-related parameters
     ----------------------
@@ -702,10 +856,17 @@
         saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)

     # FP8 wrapper.
+    old_call_funcs = {}
+
     def wrap_autocast(block):
-        old_forward = block.forward
+        block_cls = type(block)
+        if block_cls in old_call_funcs:
+            return
+        old_call_funcs[block_cls] = block_cls.__call__

-        def forward_func(*args, **kwargs):
+        # Wrap the original call function of the module class.
+        def call_func(*args, **kwargs):
             with fp8_autocast(
                 enabled=fp8_enabled,
                 calibrating=fp8_calibrating,
@@ -713,10 +874,10 @@
                 fp8_group=fp8_group,
                 _graph=True,
             ):
-                outputs = old_forward(*args, **kwargs)
+                outputs = old_call_funcs[block_cls](*args, **kwargs)
             return outputs

-        block.forward = forward_func
+        block_cls.__call__ = call_func

     forward_funcs = []
     for module in modules:
@@ -747,8 +908,10 @@
         fp8_weight_caching=fp8_weight_caching,
         sample_kwargs=sample_kwargs,
         _order=_order,
+        _num_layers_per_chunk=_num_layers_per_chunk,
         pool=pool,
         retain_graph_in_backward=retain_graph_in_backward,
+        _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers,
     )

     # Ensures warmup does not affect numerics for ops such as dropout.
@@ -758,6 +921,10 @@
     else:
         torch.cuda.set_rng_state(original_rng_states)

+    # Remove FP8 wrapper.
+    for module_cls, old_call in old_call_funcs.items():
+        module_cls.__call__ = old_call
+
     # Restore FP8 state.
     restore_fp8_tensors(modules, saved_fp8_tensors)
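
The wrapper change above patches the module class's `__call__` once per class rather than each instance's `forward`. A self-contained sketch of that pattern, with the `fp8_autocast` context stubbed out:

```python
# Minimal illustration of class-level __call__ patching (context manager stubbed).
import torch

old_call_funcs = {}

def wrap(block: torch.nn.Module):
    block_cls = type(block)
    if block_cls in old_call_funcs:  # patch each class only once
        return
    old_call_funcs[block_cls] = block_cls.__call__

    def call_func(*args, **kwargs):
        # the real wrapper enters fp8_autocast(...) here
        return old_call_funcs[block_cls](*args, **kwargs)

    block_cls.__call__ = call_func

lin = torch.nn.Linear(4, 4)
wrap(lin)
out = lin(torch.randn(2, 4))  # routed through call_func

# Remove the patch, as the end of make_graphed_callables now does.
for cls, old in old_call_funcs.items():
    cls.__call__ = old
```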
...
@@ -7,7 +7,7 @@ from __future__ import annotations

 import functools
 import math
 import os
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -625,3 +625,111 @@ if torch_version() >= (2, 4, 0):
     gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
 else:
     gpu_autocast_ctx = torch.cuda.amp.autocast
+
+
+_torch_dtype_to_np_typestr_dict = {
+    torch.float16: "<f2",
+    torch.float32: "<f4",
+    torch.int64: "<i8",
+    torch.int32: "<i4",
+    torch.int8: "|i1",
+    torch.float8_e4m3fn: "|i1",
+    torch.qint8: "|u1",
+    torch.bool: "|b1",
+    torch.bfloat16: "<f2",
+}
+
+
+class _WeakRefTensor:
+    """
+    A wrapper that wraps a raw data pointer into a tensor-like object. It is
+    compatible with OpenAI Triton kernels and can be converted to `torch.Tensor`
+    with zero-copy overhead.
+    """
+
+    def __init__(
+        self,
+        data_ptr: int,
+        dtype: torch.dtype,
+        shape: Sequence[int],
+    ):
+        self._data_ptr = data_ptr
+        self.dtype = dtype
+        self.shape = shape
+
+    def data_ptr(self):
+        """Data pointer of the tensor."""
+        return self._data_ptr
+
+    @property
+    def dtype(self):
+        """Dtype of the tensor."""
+        return self._dtype
+
+    @property
+    def shape(self):
+        """Shape of the tensor."""
+        return getattr(self, "_shape", None)
+
+    @dtype.setter
+    def dtype(self, dtype: torch.dtype):
+        self._dtype = dtype
+
+    @shape.setter
+    def shape(self, shape: Sequence[int]):
+        self._shape = tuple(int(i) for i in shape)
+
+    def numel(self):
+        """Number of elements in the tensor."""
+        return np.prod(self.shape)
+
+    @property
+    def __cuda_array_interface__(self):
+        return {
+            "shape": self.shape,
+            "typestr": self.torch_dtype_to_np_typestr(),
+            "data": (self.data_ptr() if self.numel() > 0 else 0, False),
+            "version": 3,
+        }
+
+    def torch_dtype_to_np_typestr(self):
+        """Convert PyTorch dtype to numpy typestr."""
+        ret = _torch_dtype_to_np_typestr_dict.get(self.dtype)
+        assert ret is not None, f"Unsupported dtype: {self.dtype}"
+        return ret
+
+
+def make_weak_ref(x):
+    """
+    Make a weak reference to the input so that its memory can be released.
+    """
+
+    def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torch.Tensor:
+        """
+        Convert a `_WeakRefTensor` to a `torch.Tensor`.
+        """
+        if isinstance(tensor, torch.Tensor):
+            return tensor
+        old_ptr = tensor.data_ptr()
+        new_tensor = torch.as_tensor(tensor).view(tensor.dtype)
+        new_ptr = new_tensor.data_ptr()
+        if old_ptr != new_ptr:
+            raise RuntimeError("Data pointer mismatch after converting to torch.Tensor")
+        return new_tensor
+
+    if isinstance(x, torch.Tensor):
+        return (
+            convert_to_torch_tensor(_WeakRefTensor(x.data_ptr(), x.dtype, x.shape))
+            if x.is_cuda
+            else x
+        )
+    if isinstance(x, tuple):
+        return tuple(make_weak_ref(i) for i in x)
+    if isinstance(x, list):
+        return [make_weak_ref(i) for i in x]
+    if isinstance(x, dict):
+        return {k: make_weak_ref(v) for k, v in x.items()}
+    if isinstance(x, (int, float, bool)):
+        return x
+    if x is None:
+        return None
+    raise TypeError(f"Invalid type {type(x)} to make weak ref")