Unverified Commit 466a6c82 authored by Xiangxu-0103's avatar Xiangxu-0103 Committed by GitHub
Browse files

[Fix] fix deconv_flops_counter_hooker (#1760)

* fix deconv_flops_counter

* Update generalized_attention.py
parent ea64b512
...@@ -351,7 +351,7 @@ class GeneralizedAttention(nn.Module): ...@@ -351,7 +351,7 @@ class GeneralizedAttention(nn.Module):
repeat(n, 1, 1, 1) repeat(n, 1, 1, 1)
position_feat_x_reshape = position_feat_x.\ position_feat_x_reshape = position_feat_x.\
view(n, num_heads, w*w_kv, self.qk_embed_dim) view(n, num_heads, w * w_kv, self.qk_embed_dim)
position_feat_y_reshape = position_feat_y.\ position_feat_y_reshape = position_feat_y.\
view(n, num_heads, h * h_kv, self.qk_embed_dim) view(n, num_heads, h * h_kv, self.qk_embed_dim)
......
...@@ -459,7 +459,7 @@ def deconv_flops_counter_hook(conv_module, input, output): ...@@ -459,7 +459,7 @@ def deconv_flops_counter_hook(conv_module, input, output):
bias_flops = 0 bias_flops = 0
if conv_module.bias is not None: if conv_module.bias is not None:
output_height, output_width = output.shape[2:] output_height, output_width = output.shape[2:]
bias_flops = out_channels * batch_size * output_height * output_height bias_flops = out_channels * batch_size * output_height * output_width
overall_flops = overall_conv_flops + bias_flops overall_flops = overall_conv_flops + bias_flops
conv_module.__flops__ += int(overall_flops) conv_module.__flops__ += int(overall_flops)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment