"vscode:/vscode.git/clone" did not exist on "20316346808ada97d6b1bfe86cd5e6b5fe804bcd"
Unverified Commit ce37be7b authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[misc][distributed] add seed to dummy weights (#6491)

parent 7f62077a
......@@ -440,6 +440,7 @@ def initialize_dummy_weights(
model: torch.nn.Module,
low: float = -1e-3,
high: float = 1e-3,
seed: int = 1234,
) -> None:
"""Initialize model weights with random values.
......@@ -447,17 +448,25 @@ def initialize_dummy_weights(
measurements. Additionally, the model weights should not cause NaNs in the
forward pass. We empirically found that initializing the weights with
values between -1e-3 and 1e-3 works well for most models.
We use per-parameter random seed, so that dummy weights are consistent,
even if the model is partitioned across multiple devices. When the seed
is fixed, the random values generated by this function only depends on
the parameter's number of elements and its data type.
"""
for param in model.state_dict().values():
if torch.is_floating_point(param):
generator = torch.Generator(device=param.data.device)
generator.manual_seed(seed)
if torch.finfo(param.data.dtype).bits < 16:
# uniform_ doesn't support < 16-bit datatypes (FP8)
dtype = param.data.dtype
tmp_param = param.data.to(torch.float16)
tmp_param = tmp_param.uniform_(low, high).to(dtype)
tmp_param = tmp_param.uniform_(low, high,
generator=generator).to(dtype)
param.data.copy_(tmp_param)
else:
param.uniform_(low, high)
param.uniform_(low, high, generator=generator)
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment