Unverified commit 1621b252, authored by Chengji Yao, committed by GitHub
Browse files

[TPU] Fix dummy loading OOM (#16372)


Signed-off-by: Chengji Yao <chengjiyao@google.com>
parent a5647971
......@@ -658,8 +658,21 @@ def initialize_dummy_weights(
for param in model.state_dict().values():
if torch.is_floating_point(param):
if current_platform.is_tpu():
# XLA device does not support torch.Generator()
param.uniform_(low, high)
generator = torch.Generator(device="cpu")
generator.manual_seed(seed)
# Note: The param.uniform_ function cannot be used in this
# context because it demands more TPU HBM than directly copying
# from a CPU tensor.
# Note: We avoid using torch.rand_like as it doesn't currently
# support the generator argument.
param.copy_((high - low) *
torch.rand(*param.shape,
generator=generator,
dtype=param.dtype,
layout=param.layout,
requires_grad=param.requires_grad,
device="cpu") + low)
torch._sync(param)
continue
generator = torch.Generator(device=param.data.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment