Unverified Commit 1d5d5fae authored by Atream's avatar Atream Committed by GitHub
Browse files

Merge pull request #626 from cyhasuka/main

Feat: Clear cache during weight loading to prevent OOM on GPUs with <=8GB VRAM
parents 3c8c5805 8db6a4d4
...@@ -92,8 +92,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str ...@@ -92,8 +92,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
target_dtype = torch.get_default_dtype() target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map) device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}") print(f"loading {translated_key} to {device}")
torch.cuda.empty_cache() # To fit in 16G VRAM. By "wkGCaSS - 知乎 https://zhuanlan.zhihu.com/p/25491611225" torch.cuda.empty_cache()
# weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype) weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights) set_param(module, name, weights)
del weights del weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment