Unverified Commit 6b847a9a authored by JiLi's avatar JiLi Committed by GitHub
Browse files

Optimize: Cache CUDA device to reduce redundant calls during tensor loading (#8996)

parent 473400e4
...@@ -895,8 +895,12 @@ class ModelRunner: ...@@ -895,8 +895,12 @@ class ModelRunner:
named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]], named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
load_format: Optional[str] = None, load_format: Optional[str] = None,
): ):
monkey_patch_torch_reductions()
# We need to get device after patch otherwise the device would be wrong
infered_device = torch.cuda.current_device()
named_tensors = [ named_tensors = [
(name, _unwrap_tensor(tensor, tp_rank=self.tp_rank)) (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device))
for name, tensor in named_tensors for name, tensor in named_tensors
] ]
if load_format == "direct": if load_format == "direct":
...@@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso ...@@ -1809,11 +1813,10 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
default_weight_loader(params_dict[name], tensor) default_weight_loader(params_dict[name], tensor)
def _unwrap_tensor(tensor, tp_rank, device):
    """Materialize a possibly-serialized tensor onto the target device.

    If *tensor* is a ``LocalSerializedTensor``, first resolve the entry for
    this tensor-parallel rank via ``.get(tp_rank)``; a plain tensor is used
    as-is. The result is always returned on *device* (moved with ``.to``).
    """
    resolved = (
        tensor.get(tp_rank) if isinstance(tensor, LocalSerializedTensor) else tensor
    )
    return resolved.to(device)
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment