Unverified commit 47a61c3b authored by Cedric Luo, committed by GitHub

[Fix] Fix init weights of MultiScaleDeformableAttention (#2158)



* fix tensors on different device

* fix lint

* Update mmcv/ops/multi_scale_deform_attn.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent fb795962
@@ -235,9 +235,10 @@ class MultiScaleDeformableAttention(BaseModule):
     def init_weights(self) -> None:
         """Default initialization for Parameters of Module."""
         constant_init(self.sampling_offsets, 0.)
+        device = next(self.parameters()).device
         thetas = torch.arange(
-            self.num_heads,
-            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
+            self.num_heads, dtype=torch.float32,
+            device=device) * (2.0 * math.pi / self.num_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = (grid_init /
                      grid_init.abs().max(-1, keepdim=True)[0]).view(
...
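The essence of the fix: `init_weights` creates fresh tensors with `torch.arange`, which default to the CPU, so if the module has already been moved to a GPU the initialization mixes devices. Querying `next(self.parameters()).device` keeps the new tensors on the module's device. Below is a minimal, self-contained sketch of the pattern; `ToyDeformAttn`, its dimensions, and the bias assignment are hypothetical stand-ins, not the actual mmcv class.

```python
# Sketch of device-safe weight initialization (hypothetical toy module,
# not the mmcv MultiScaleDeformableAttention implementation).
import math

import torch
import torch.nn as nn


class ToyDeformAttn(nn.Module):
    def __init__(self, num_heads: int = 8, embed_dims: int = 32):
        super().__init__()
        self.num_heads = num_heads
        # One (x, y) offset per head, as a stand-in for sampling offsets.
        self.sampling_offsets = nn.Linear(embed_dims, num_heads * 2)

    def init_weights(self) -> None:
        nn.init.constant_(self.sampling_offsets.weight, 0.)
        # Key line from the fix: create new tensors on the same device as
        # the module's parameters. A bare torch.arange(...) would land on
        # the CPU and raise a device-mismatch error once the module is on
        # a GPU.
        device = next(self.parameters()).device
        thetas = torch.arange(
            self.num_heads, dtype=torch.float32,
            device=device) * (2.0 * math.pi / self.num_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        # Normalize each head's direction so its largest coordinate is 1.
        grid_init = (grid_init /
                     grid_init.abs().max(-1, keepdim=True)[0])
        with torch.no_grad():
            self.sampling_offsets.bias.copy_(grid_init.view(-1))


m = ToyDeformAttn()
m.init_weights()
```

With the module still on the CPU the behavior is unchanged; the `device=` argument only matters once the module lives on an accelerator, which is exactly the case the original code broke.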