Unverified commit fe8fad59, authored by Paweł Gadziński, committed by GitHub
Browse files

[PyTorch] Bunch of fixes for cpu offloading (#2535)



* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* test fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2f8ae81c
...@@ -54,9 +54,13 @@ gc.disable() ...@@ -54,9 +54,13 @@ gc.disable()
class Utils: class Utils:
# Tensor used for simulating long-running GPU work in long_job()
tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16) tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16)
_B = 64 # Test tensor dimensions: _B x _S x _D = 128 x 512 x 256 = 16,777,216 elements
_S = 256 # This exceeds the 256K element threshold for offloading (cpu_offload.py line 443).
# For quantized tensors, scale_inv tensors (~524K elements for block scaling) also exceed threshold.
_B = 128
_S = 512
_H = 4 _H = 4
_D = 256 _D = 256
...@@ -395,6 +399,9 @@ class TestsDefaultOffloadSynchronizer: ...@@ -395,6 +399,9 @@ class TestsDefaultOffloadSynchronizer:
offload_synchronizer.push_tensor(x1) offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1) offload_synchronizer.push_tensor(x1)
offload_synchronizer.push_tensor(x1) offload_synchronizer.push_tensor(x1)
# Verify x1 is not corrupted after pushing (important for QuantizedTensor)
if recipe is not None:
x1.dequantize() # Should not raise - tensor should still be valid
offload_synchronizer.fwd_step() offload_synchronizer.fwd_step()
# Only one copy of tensor on cpu is allocated. # Only one copy of tensor on cpu is allocated.
assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1) assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1)
......
...@@ -19,6 +19,7 @@ import transformer_engine.pytorch.cpu_offload_v1 as v1_code_path ...@@ -19,6 +19,7 @@ import transformer_engine.pytorch.cpu_offload_v1 as v1_code_path
from .quantized_tensor import ( from .quantized_tensor import (
restore_from_saved, restore_from_saved,
prepare_for_saving, prepare_for_saving,
QuantizedTensor,
) )
...@@ -255,6 +256,8 @@ class OffloadableLayerState: ...@@ -255,6 +256,8 @@ class OffloadableLayerState:
Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream. Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream.
Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded. Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded.
This event is recorded in the start_offload or push_tensor call. This event is recorded in the start_offload or push_tensor call.
Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor).
""" """
self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"]) self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"])
self.state = "offload_started" self.state = "offload_started"
...@@ -275,19 +278,18 @@ class OffloadableLayerState: ...@@ -275,19 +278,18 @@ class OffloadableLayerState:
with torch.cuda.stream(self.offload_stream): with torch.cuda.stream(self.offload_stream):
if allocate_cpu_buffers: if allocate_cpu_buffers:
# empty_like is defined also for QuantizedTensors
offloaded_tensor = torch.empty_like( offloaded_tensor = torch.empty_like(
tensor, device=torch.device("cpu"), pin_memory=True tensor, device=torch.device("cpu"), pin_memory=True
) )
self.cpu_tensor_group.tensor_list.append(offloaded_tensor) self.cpu_tensor_group.tensor_list.append(offloaded_tensor)
else: else:
assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, ( offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
assert offloaded_tensor.shape == tensor.shape, (
"CPU buffer shape does not match the offloaded tensor shape:" "CPU buffer shape does not match the offloaded tensor shape:"
f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} " f" {offloaded_tensor.shape} != {tensor.shape} "
" Make sure that tensor shaped do not change between" "Make sure that tensor shapes do not change between"
" iterations if retain_pinned_cpu_buffers is True." " iterations if retain_pinned_cpu_buffers is True."
) )
offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id]
offloaded_tensor.copy_(tensor, non_blocking=True) offloaded_tensor.copy_(tensor, non_blocking=True)
# aux is a dictionary that contains auxiliary data like information which tensors were deduplicated, # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated,
...@@ -318,6 +320,9 @@ class OffloadableLayerState: ...@@ -318,6 +320,9 @@ class OffloadableLayerState:
""" """
Start reloading of tensors. Start reloading of tensors.
It allocates new tensors on GPU and puts copy from CPU tasks on offload stream. It allocates new tensors on GPU and puts copy from CPU tasks on offload stream.
Note: tensor_list only contains regular tensors (QuantizedTensors are decomposed in push_tensor
and reconstructed in pop_tensor).
""" """
self._validate_state(func_name="start_reload", allowed_states=["offload_finished"]) self._validate_state(func_name="start_reload", allowed_states=["offload_finished"])
self.state = "reload_started" self.state = "reload_started"
...@@ -330,7 +335,6 @@ class OffloadableLayerState: ...@@ -330,7 +335,6 @@ class OffloadableLayerState:
# cannot move tensors from pool of one stream to another without # cannot move tensors from pool of one stream to another without
# calling cudaFree and cudaMalloc again. # calling cudaFree and cudaMalloc again.
# empty_like is defined also for QuantizedTensors.
reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda")) reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda"))
self.offload_stream.wait_stream(torch.cuda.current_stream()) self.offload_stream.wait_stream(torch.cuda.current_stream())
...@@ -347,16 +351,29 @@ class OffloadableLayerState: ...@@ -347,16 +351,29 @@ class OffloadableLayerState:
self.bwd_gpu_tensor_group self.bwd_gpu_tensor_group
) )
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
""" """
It is called when a tensor is saved for backward pass. It is called when a tensor is saved for backward pass.
If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group. If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group.
If tensor is not offloaded, returns the tensor itself. If tensor is not offloaded, returns the tensor itself.
For QuantizedTensor, returns (list of push results for each component, tensor_objs) tuple.
""" """
self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"]) self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"])
if self._check_if_offload(tensor): if self._check_if_offload(tensor):
# For QuantizedTensor: decompose into component tensors, push each one recursively
if isinstance(tensor, QuantizedTensor):
# Make a copy because prepare_for_saving modifies the object (sets fields to None)
tensor_copy = tensor.detach()
# Inline prepare_for_saving logic - QuantizedTensor is a torch.Tensor subclass,
# so the generic prepare_for_saving would not call tensor.prepare_for_saving()
saved_tensors, tensor_obj = tensor_copy.prepare_for_saving()
push_results = [
self.push_tensor(t) if t is not None else None for t in saved_tensors
]
return (push_results, [tensor_obj])
self.fwd_gpu_tensor_group.tensor_list.append(tensor) self.fwd_gpu_tensor_group.tensor_list.append(tensor)
# The group is processed and offloaded at the end of the forward pass of current layer. # The group is processed and offloaded at the end of the forward pass of current layer.
# To enable offloading of tensors faster we use self.offload_stream and record # To enable offloading of tensors faster we use self.offload_stream and record
...@@ -370,23 +387,39 @@ class OffloadableLayerState: ...@@ -370,23 +387,39 @@ class OffloadableLayerState:
return len(self.fwd_gpu_tensor_group.tensor_list) - 1 return len(self.fwd_gpu_tensor_group.tensor_list) - 1
return tensor return tensor
def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor: def pop_tensor(
self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
) -> torch.Tensor:
""" """
It is called when a tensor is used in backward pass. It is called when a tensor is used in backward pass.
Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish. Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish.
For QuantizedTensor (tuple input), reconstructs from component tensors.
""" """
self._validate_state( self._validate_state(
func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"] func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"]
) )
# 1. tensor not offloaded # 1. tensor not offloaded (regular tensor returned as-is from push)
if isinstance(tensor_or_tensor_id, torch.Tensor): if isinstance(tensor_or_tensor_id, torch.Tensor):
return tensor_or_tensor_id return tensor_or_tensor_id
# 2. the layer was not offloaded at all
# 2. QuantizedTensor case: tuple of (push_results, tensor_objs)
if isinstance(tensor_or_tensor_id, tuple):
push_results, tensor_objs = tensor_or_tensor_id
# Recursively pop each component
reloaded_tensors = [
self.pop_tensor(pr) if pr is not None else None for pr in push_results
]
# Inline restore_from_saved - tensor_objs[0] is the QuantizedTensor copy
tensor_obj = tensor_objs[0]
tensor_obj.restore_from_saved(reloaded_tensors)
return tensor_obj
# 3. Regular tensor index case
if self.state == "not_offloaded": if self.state == "not_offloaded":
return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id]
# 3. the layer was offloaded # 4. the layer was offloaded
assert self.state == "reload_started" assert self.state == "reload_started"
# wait for the tensor to be reloaded # wait for the tensor to be reloaded
torch.cuda.current_stream().wait_event( torch.cuda.current_stream().wait_event(
...@@ -406,6 +439,10 @@ class OffloadableLayerState: ...@@ -406,6 +439,10 @@ class OffloadableLayerState:
""" """
Check if tensor needs to be offloaded. Check if tensor needs to be offloaded.
""" """
# Only offload tensors with at least 256k elements (~1MB for float32)
if t.numel() < 256 * 1024:
return False
if ( if (
not isinstance(t, torch.nn.Parameter) not isinstance(t, torch.nn.Parameter)
and not getattr(t, "_TE_do_not_offload", False) and not getattr(t, "_TE_do_not_offload", False)
...@@ -418,7 +455,6 @@ class OffloadableLayerState: ...@@ -418,7 +455,6 @@ class OffloadableLayerState:
" this tensor will be skipped." " this tensor will be skipped."
) )
return False return False
return True return True
return False return False
...@@ -488,11 +524,13 @@ class OffloadSynchronizer: ...@@ -488,11 +524,13 @@ class OffloadSynchronizer:
self.previous_bwd_layer_id = layer_num self.previous_bwd_layer_id = layer_num
self.current_layer_id = layer_num self.current_layer_id = layer_num
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""Default push tensor method""" """Default push tensor method"""
return self.layer_states[self.num_of_fwds].push_tensor(tensor) return self.layer_states[self.num_of_fwds].push_tensor(tensor)
def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor: def pop_tensor(
self, tensor_or_tensor_id: torch.Tensor | int | tuple[list, list]
) -> torch.Tensor:
"""Default pop tensor method""" """Default pop tensor method"""
return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id) return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id)
...@@ -592,6 +630,12 @@ class DefaultOffloadSynchronizer(OffloadSynchronizer): ...@@ -592,6 +630,12 @@ class DefaultOffloadSynchronizer(OffloadSynchronizer):
for layer in self.start_reload_map[layer_num]: for layer in self.start_reload_map[layer_num]:
self.layer_states[layer].start_reload() self.layer_states[layer].start_reload()
def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor | tuple[list, list]:
"""Push tensor - skip processing if layer won't be offloaded to reduce CPU overhead."""
if not self.offload_layer_map.get(self.num_of_fwds, False):
return tensor
return self.layer_states[self.num_of_fwds].push_tensor(tensor)
class ManualOffloadSynchronizer(OffloadSynchronizer): class ManualOffloadSynchronizer(OffloadSynchronizer):
""" """
......
...@@ -254,7 +254,8 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten ...@@ -254,7 +254,8 @@ std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &ten
std::vector<py::object> split_quantize(const at::Tensor &tensor, std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections, const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list); std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation = false);
/*************************************************************************************************** /***************************************************************************************************
* Bias gradient fusions * Bias gradient fusions
......
...@@ -1095,7 +1095,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, ...@@ -1095,7 +1095,8 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input,
std::vector<py::object> split_quantize(const at::Tensor &tensor, std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<size_t> &split_sections, const std::vector<size_t> &split_sections,
std::vector<py::handle> quantizer_list) { std::vector<py::handle> quantizer_list,
bool disable_bulk_allocation) {
init_extension(); init_extension();
// Check number of tensors // Check number of tensors
...@@ -1147,22 +1148,24 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor, ...@@ -1147,22 +1148,24 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 };
AllocationMethod allocation_method = AllocationMethod::UNFUSED; AllocationMethod allocation_method = AllocationMethod::UNFUSED;
QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED;
if (std::all_of(quantizer_list.begin(), quantizer_list.end(), if (!disable_bulk_allocation) {
[](const py::handle &quantizer) -> bool { if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr()); [](const py::handle &quantizer) -> bool {
})) { return detail::IsFloat8BlockwiseQuantizers(quantizer.ptr());
allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE; })) {
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), allocation_method = AllocationMethod::BULK_FP8_BLOCKWISE;
[](const py::handle &quantizer) -> bool { } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
return detail::IsMXFP8Quantizers(quantizer.ptr()); [](const py::handle &quantizer) -> bool {
})) { return detail::IsMXFP8Quantizers(quantizer.ptr());
allocation_method = AllocationMethod::BULK_MXFP8; })) {
} else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), allocation_method = AllocationMethod::BULK_MXFP8;
[](const py::handle &quantizer) -> bool { } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(),
return detail::IsNVFP4Quantizers(quantizer.ptr()); [](const py::handle &quantizer) -> bool {
})) { return detail::IsNVFP4Quantizers(quantizer.ptr());
allocation_method = AllocationMethod::BULK_NVFP4; })) {
quantization_method = QuantizationMethod::FUSED_NVFP4; allocation_method = AllocationMethod::BULK_NVFP4;
quantization_method = QuantizationMethod::FUSED_NVFP4;
}
} }
// Allocate output tensors // Allocate output tensors
......
...@@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -248,7 +248,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
m.def("split_quantize", &transformer_engine::pytorch::split_quantize, m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
"Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
py::arg("quantizer_list")); py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false);
m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
"Grouped GEMM"); "Grouped GEMM");
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
......
...@@ -143,7 +143,12 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -143,7 +143,12 @@ class _GroupedLinear(torch.autograd.Function):
inp_view = inp.reshape(-1, in_features) inp_view = inp.reshape(-1, in_features)
inputmats: list inputmats: list
if fp8 and not debug: if fp8 and not debug:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers) # Disable bulk allocation when CPU offloading is active: offloading skips small
# tensors (like scales), but bulk allocation shares storage across all tensors,
# so if scales can't be offloaded, nothing in the group can be offloaded.
inputmats = tex.split_quantize(
inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
)
elif debug: elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize( inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, input_quantizers, m_splits, activation_dtype inp_view, input_quantizers, m_splits, activation_dtype
......
...@@ -428,7 +428,8 @@ class _Linear(torch.autograd.Function): ...@@ -428,7 +428,8 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module # weights if weights are externally touched outside this module
ctx.weight_object = weight ctx.weight_object = weight
mark_not_offload(weight, weightmat, bias) if cpu_offloading:
mark_not_offload(weight, weightmat, bias)
# TODO(ksivamani): Check memory usage # TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat, saved_inputmat,
......
...@@ -14,6 +14,7 @@ import torch ...@@ -14,6 +14,7 @@ import torch
from torch.distributed._tensor import DTensor from torch.distributed._tensor import DTensor
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
...@@ -372,10 +373,12 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -372,10 +373,12 @@ class FusedAdam(torch.optim.Optimizer):
store_param_remainders (bool): Store only trailing remainder bits. store_param_remainders (bool): Store only trailing remainder bits.
""" """
dtype = self.name_to_dtype_map[state_name] dtype = self.name_to_dtype_map[state_name]
# Handle QuantizedTensor by dequantizing first
param_for_empty = param.dequantize() if isinstance(param, QuantizedTensor) else param
if store_param_remainders: if store_param_remainders:
data = torch.zeros(param.shape, dtype=torch.int16, device=param.device) data = torch.zeros_like(param_for_empty, dtype=torch.int16)
else: else:
data = torch.empty(param.shape, dtype=dtype, device=param.device) data = torch.empty_like(param_for_empty, dtype=dtype)
if zero_buffer: if zero_buffer:
data.zero_() data.zero_()
......
...@@ -20,11 +20,6 @@ from transformer_engine.pytorch.tensor._quantization_helpers import ( ...@@ -20,11 +20,6 @@ from transformer_engine.pytorch.tensor._quantization_helpers import (
_stride_from_shape, _stride_from_shape,
) )
_quantized_tensor_cpu_supported_ops = (
torch.ops.aten.empty_like.default,
torch.ops.aten.copy_.default,
)
class QuantizedTensorStorage: class QuantizedTensorStorage:
r"""Base class for all TensorStorage classes. r"""Base class for all TensorStorage classes.
...@@ -539,15 +534,6 @@ class QuantizedTensor(torch.Tensor): ...@@ -539,15 +534,6 @@ class QuantizedTensor(torch.Tensor):
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
def check_if_cpu(arg):
if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu":
assert (
func in _quantized_tensor_cpu_supported_ops
), f"QuantizedTensor on CPU does not support this operation: {func}"
return arg
args = tree_map(check_if_cpu, args)
# Do not force the QuantizedTensor type on the returned tensor # Do not force the QuantizedTensor type on the returned tensor
return torch._C._disabled_torch_function_impl(func, types, args, kwargs) return torch._C._disabled_torch_function_impl(func, types, args, kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment