Unverified Commit 961373ad authored by Yanhong Zeng's avatar Yanhong Zeng Committed by GitHub
Browse files

[Fix] cast the type of mask to enable training with amp (#2220)

parent c3835415
...@@ -69,6 +69,7 @@ class ModulatedDeformConv2dFunction(Function): ...@@ -69,6 +69,7 @@ class ModulatedDeformConv2dFunction(Function):
input = input.type_as(offset) input = input.type_as(offset)
weight = weight.type_as(input) weight = weight.type_as(input)
bias = bias.type_as(input) # type: ignore bias = bias.type_as(input) # type: ignore
mask = mask.type_as(input)
ctx.save_for_backward(input, offset, mask, weight, bias) ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty( output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment