Commit c9861a03 authored by Tri Dao's avatar Tri Dao
Browse files

[LayerNorm] Initialize mean and rstd tensor using x.device

parent 99ea4baa
......@@ -314,8 +314,8 @@ def _layer_norm_fwd(
assert residual_out.stride(-1) == 1
else:
residual_out = None
mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
if dropout_p > 0.0:
seeds = torch.randint(
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment