Unverified Commit 841a078e authored by Rui Xu, committed by GitHub

[Fix]: fix data type in fused-bias-leakyrelu for apex fp16 training (#981)

parent 9649a9ad
@@ -45,10 +45,9 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
         # The second order deviation, in fact, contains two parts, while the
         # the first part is zero. Thus, we direct consider the second part
         # which is similar with the first order deviation in implementation.
-        gradgrad_out = ext_module.fused_bias_leakyrelu(gradgrad_input,
-                                                       gradgrad_bias, out, 3,
-                                                       1, ctx.negative_slope,
-                                                       ctx.scale)
+        gradgrad_out = ext_module.fused_bias_leakyrelu(
+            gradgrad_input, gradgrad_bias.to(out.dtype), out, 3, 1,
+            ctx.negative_slope, ctx.scale)
         return gradgrad_out, None, None, None

@@ -139,7 +138,8 @@ def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
     if not input.is_cuda:
         return bias_leakyrelu_ref(input, bias, negative_slope, scale)
-    return FusedBiasLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
+                                            negative_slope, scale)


 def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
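
Under apex mixed-precision training, the activations reaching fused_bias_leakyrelu can be float16 while the bias parameter is still float32, and the fused CUDA kernel expects matching dtypes. The commit therefore casts the bias to the activation dtype in the forward call (bias.to(input.dtype)) and the second-order bias gradient to the output dtype in the backward pass (gradgrad_bias.to(out.dtype)). Below is a minimal, runnable pure-PyTorch sketch of that cast; the wrapper name fused_bias_leakyrelu_safe and its reference body are illustrative assumptions, not the mmcv implementation.

    # Minimal sketch of the dtype-alignment pattern from this commit.
    import torch
    import torch.nn.functional as F

    def fused_bias_leakyrelu_safe(input, bias, negative_slope=0.2, scale=2**0.5):
        # Align the bias dtype with the activation dtype, mirroring the
        # `bias.to(input.dtype)` cast added in this commit.
        bias = bias.to(input.dtype)
        # Pure-PyTorch stand-in for the fused op (bias add + leaky ReLU + scale),
        # used here only so the sketch runs without the compiled extension.
        bias = bias.reshape((1, -1) + (1,) * (input.ndim - 2))
        return F.leaky_relu(input + bias, negative_slope) * scale

    x = torch.randn(2, 8, 4, 4)
    if torch.cuda.is_available():
        x = x.cuda().half()              # apex casts activations to fp16
    b = torch.zeros(8, device=x.device)  # bias parameter stays fp32
    y = fused_bias_leakyrelu_safe(x, b)
    assert y.dtype == x.dtype

Without the cast, the fp32 bias would either be rejected by the fused kernel or promote the result back to fp32; casting at the call site keeps the parameter itself in fp32 for the optimizer while the kernel sees consistent fp16 inputs.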