Unverified Commit 66b0d9ee authored by Donggeun Yu's avatar Donggeun Yu Committed by GitHub
Browse files

DeformableDETR two stage support bfloat16 (#30907)

Update modeling_deformable_detr.py
parent 5d0bf59b
......@@ -1616,8 +1616,8 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = meshgrid(
torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device),
torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device),
indexing="ij",
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment