Unverified commit 8a0ed0a9 authored by amyeroberts, committed by GitHub
Browse files

Fix copies between DETR and DETA (#29037)

parent 5b6fa230
...@@ -627,7 +627,8 @@ class DetaMultiscaleDeformableAttention(nn.Module): ...@@ -627,7 +627,8 @@ class DetaMultiscaleDeformableAttention(nn.Module):
def _reset_parameters(self): def _reset_parameters(self):
nn.init.constant_(self.sampling_offsets.weight.data, 0.0) nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads) default_dtype = torch.get_default_dtype()
thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = ( grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0]) (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment