Unverified Commit 6b847a9a authored by JiLi, committed by GitHub

Optimize: Cache CUDA device to reduce redundant calls during tensor l… (#8996)

parent 473400e4
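
At a glance: before this commit, `_unwrap_tensor` queried `torch.cuda.current_device()` once per tensor during weight loading; the change hoists that query into the caller, which resolves the device a single time and threads it through. A minimal, self-contained sketch of the caching pattern (illustrative only; `move_all` and its toy inputs are hypothetical, not sglang APIs):

```python
import torch

def move_all(named_tensors):
    # Resolve the target device once, outside the loop, instead of calling
    # torch.cuda.current_device() for every tensor (the point of this commit).
    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    return [(name, t.to(device)) for name, t in named_tensors]

print(move_all([("w", torch.ones(2)), ("b", torch.zeros(2))]))
```
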
@@ -895,8 +895,12 @@ class ModelRunner:
         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
         load_format: Optional[str] = None,
     ):
+        monkey_patch_torch_reductions()
+        # We need to get the device after the patch; otherwise the device would be wrong.
+        inferred_device = torch.cuda.current_device()
+
         named_tensors = [
-            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=inferred_device))
             for name, tensor in named_tensors
         ]
         if load_format == "direct":
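
The comment in the hunk above is the subtle part: per the diff's own note, reading the device before `monkey_patch_torch_reductions()` runs yields the wrong device, so the patch must come first and the device is cached once afterwards. A hedged sketch of that ordering, with a no-op stub standing in for sglang's helper (its real internals are not shown in this diff):

```python
import torch

def monkey_patch_torch_reductions_stub():
    # No-op stand-in for sglang's monkey_patch_torch_reductions(), which
    # patches torch's tensor reductions; internals omitted here.
    pass

monkey_patch_torch_reductions_stub()  # patch first...
# ...then read and cache the device once for the whole load.
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
```
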
@@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
         default_weight_loader(params_dict[name], tensor)


-def _unwrap_tensor(tensor, tp_rank):
+def _unwrap_tensor(tensor, tp_rank, device):
     if isinstance(tensor, LocalSerializedTensor):
-        monkey_patch_torch_reductions()
         tensor = tensor.get(tp_rank)
-    return tensor.to(torch.cuda.current_device())
+    return tensor.to(device)


 @dataclass
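
For context, a self-contained toy mirroring the new `_unwrap_tensor` contract (the names here are hypothetical stand-ins; judging from `tensor.get(tp_rank)` in the diff, sglang's `LocalSerializedTensor` appears to hold one serialized shard per TP rank, which `ToySerializedTensor` only approximates):

```python
from dataclasses import dataclass
from typing import Dict
import torch

@dataclass
class ToySerializedTensor:
    # Stand-in for sglang's LocalSerializedTensor: one CPU shard per TP rank.
    shards: Dict[int, torch.Tensor]

    def get(self, tp_rank: int) -> torch.Tensor:
        return self.shards[tp_rank]

def unwrap(tensor, tp_rank: int, device) -> torch.Tensor:
    # Mirrors the new signature: the caller supplies the device; the helper
    # no longer queries torch.cuda.current_device() itself.
    if isinstance(tensor, ToySerializedTensor):
        tensor = tensor.get(tp_rank)
    return tensor.to(device)

device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
print(unwrap(ToySerializedTensor(shards={0: torch.zeros(4)}), tp_rank=0, device=device))
```
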