Unverified Commit 47962467 authored by Yosuke Shinya's avatar Yosuke Shinya Committed by GitHub
Browse files

[Fix] Fix DCN fp16 (#1014)

* [Fix] Fix DCN fp16

* add comment
parent f61295d9
...@@ -70,6 +70,9 @@ class DeformConv2dFunction(Function): ...@@ -70,6 +70,9 @@ class DeformConv2dFunction(Function):
ctx.deform_groups = deform_groups ctx.deform_groups = deform_groups
ctx.im2col_step = im2col_step ctx.im2col_step = im2col_step
# until the code is modified for torch.cuda.amp.autocast,
# we need to cast weight to avoid type mismatch in fp16 training
weight = weight.type_as(input)
ctx.save_for_backward(input, offset, weight) ctx.save_for_backward(input, offset, weight)
output = input.new_empty( output = input.new_empty(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment