Unverified Commit 5be95934 authored by Guangchen Lin, committed by GitHub

[Fix] fix generalized attention fp16 (#1036)

* Fix generalized attention in fp16 mode

* Fix build error when building without a GPU

* Add comment

* Cast tensors at initialization
parent 1a66977f
@@ -170,18 +170,23 @@ class GeneralizedAttention(nn.Module):
                                q_stride,
                                kv_stride,
                                device,
+                               dtype,
                                feat_dim,
                                wave_length=1000):
-        h_idxs = torch.linspace(0, h - 1, h).to(device)
+        # the default type of Tensor is float32, leading to type mismatch
+        # in fp16 mode. Cast it to support fp16 mode.
+        h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
         h_idxs = h_idxs.view((h, 1)) * q_stride
-        w_idxs = torch.linspace(0, w - 1, w).to(device)
+        w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
         w_idxs = w_idxs.view((w, 1)) * q_stride
-        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(device)
+        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+            device=device, dtype=dtype)
         h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
-        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(device)
+        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+            device=device, dtype=dtype)
         w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride

         # (h, h_kv, 1)
@@ -192,9 +197,10 @@ class GeneralizedAttention(nn.Module):
         w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
         w_diff *= self.position_magnitude

-        feat_range = torch.arange(0, feat_dim / 4).to(device)
+        feat_range = torch.arange(0, feat_dim / 4).to(
+            device=device, dtype=dtype)

-        dim_mat = torch.Tensor([wave_length]).to(device)
+        dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
         dim_mat = dim_mat**((4. / feat_dim) * feat_range)
         dim_mat = dim_mat.view((1, 1, -1))
@@ -234,7 +240,7 @@ class GeneralizedAttention(nn.Module):
         if self.attention_type[1] or self.attention_type[3]:
             position_embed_x, position_embed_y = self.get_position_embedding(
                 h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
-                x_input.device, self.position_embedding_dim)
+                x_input.device, x_input.dtype, self.position_embedding_dim)
             # (n, num_heads, w, w_kv, dim)
             position_feat_x = self.appr_geom_fc_x(position_embed_x).\
                 view(1, w, w_kv, num_heads, self.qk_embed_dim).\
...
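Why the patch works: factory functions such as torch.linspace and torch.arange return float32 tensors by default, while in fp16 mode the fc layers that consume the position embedding (e.g. appr_geom_fc_x) hold half-precision weights, so the forward pass fails with a dtype mismatch. A minimal sketch of the failure and the fix, assuming a stand-in fc layer and shapes that are not taken from this patch:

    import torch
    import torch.nn as nn

    if torch.cuda.is_available():
        # Stand-in for a half-precision fc layer such as appr_geom_fc_x.
        fc = nn.Linear(8, 4).cuda().half()
        # torch.linspace defaults to float32; only the device is set here.
        embed = torch.linspace(0, 7, 8).cuda().repeat(2, 1)
        # fc(embed) would raise "expected scalar type Half but found Float".
        # The fix: thread the dtype through, as the patch does with
        # .to(device=device, dtype=dtype).
        embed = embed.to(dtype=fc.weight.dtype)
        out = fc(embed)
        assert out.dtype == torch.half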
@@ -60,3 +60,16 @@ def test_context_block():
     assert gen_attention_block.kv_downsample is not None
     out = gen_attention_block(imgs)
     assert out.shape == imgs.shape
+
+    # test fp16 with attention_type='1111'
+    if torch.cuda.is_available():
+        imgs = torch.randn(2, 16, 20, 20).cuda().to(torch.half)
+        gen_attention_block = GeneralizedAttention(
+            16,
+            spatial_range=-1,
+            num_heads=8,
+            attention_type='1111',
+            kv_stride=2)
+        gen_attention_block.cuda().type(torch.half)
+        out = gen_attention_block(imgs)
+        assert out.shape == imgs.shape
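The new test mirrors how a caller enables fp16: both the input and the module are cast to half. Note that Module.half() (or .type(torch.half)) casts parameters and buffers only; tensors created inside forward() keep the float32 default, which is exactly why get_position_embedding now receives x_input.dtype. A small sketch of that pattern, using a hypothetical module that is not from this repo:

    import torch
    import torch.nn as nn

    class AddPositionIdx(nn.Module):
        """Hypothetical module illustrating the dtype-threading pattern."""

        def forward(self, x):
            # Created inside forward(), so .half() on the module cannot cast
            # it; follow the input's dtype instead, like x_input.dtype above.
            idxs = torch.linspace(0, x.size(-1) - 1, x.size(-1)).to(
                device=x.device, dtype=x.dtype)
            return x + idxs

    if torch.cuda.is_available():
        m = AddPositionIdx().cuda().half()
        x = torch.randn(2, 16).cuda().half()
        assert m(x).dtype == torch.half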