Unverified Commit 9351a179 authored by guyueh1's avatar guyueh1 Committed by GitHub
Browse files

Fix a crash in NeMo 2.0 during module._apply(lambda t: t.cpu()) (#1502)



* Fix a crash with module._apply(lambda t: t.cpu())
Signed-off-by: default avatarGuyue Huang <guyueh@nvidia.com>

* Add comments
Signed-off-by: default avatarGuyue Huang <guyueh@nvidia.com>

* Make sure tensor is moved to dst device before quantizer quantizes
Signed-off-by: default avatarGuyue Huang <guyueh@nvidia.com>

---------
Signed-off-by: default avatarGuyue Huang <guyueh@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 87441885
...@@ -484,6 +484,8 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -484,6 +484,8 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
# Tensor device # Tensor device
new_device = tensor.device if tensor.is_cuda else self.device new_device = tensor.device if tensor.is_cuda else self.device
if not devices_match(new_device, tensor.device):
tensor = tensor.to(device=new_device)
# Just copy FP8 data if other tensor is Float8Tensor # Just copy FP8 data if other tensor is Float8Tensor
if isinstance(tensor, Float8Tensor): if isinstance(tensor, Float8Tensor):
......
...@@ -368,6 +368,8 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -368,6 +368,8 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
# Tensor device # Tensor device
new_device = tensor.device if tensor.is_cuda else self.device new_device = tensor.device if tensor.is_cuda else self.device
if not devices_match(new_device, tensor.device):
tensor = tensor.to(device=new_device)
# Just copy FP8 data if other tensor is MXFP8Tensor # Just copy FP8 data if other tensor is MXFP8Tensor
if isinstance(tensor, MXFP8Tensor): if isinstance(tensor, MXFP8Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment