Unverified Commit ea3e9789 authored by pc's avatar pc Committed by GitHub
Browse files

fix parrots op bug (#1289)

parent f022d577
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os import os
import torch
def is_custom_op_loaded(): def is_custom_op_loaded():
flag = False flag = False
...@@ -16,4 +18,4 @@ def is_custom_op_loaded(): ...@@ -16,4 +18,4 @@ def is_custom_op_loaded():
flag = os.path.exists(ort_lib_path) flag = os.path.exists(ort_lib_path)
except (ImportError, ModuleNotFoundError): except (ImportError, ModuleNotFoundError):
pass pass
return flag return flag or torch.__version__ == 'parrots'
...@@ -45,7 +45,7 @@ void sync_bn_forward_output_cuda_parrots(CudaContext& ctx, ...@@ -45,7 +45,7 @@ void sync_bn_forward_output_cuda_parrots(CudaContext& ctx,
auto running_var = buildATensor(ctx, outs[1]); auto running_var = buildATensor(ctx, outs[1]);
auto norm = buildATensor(ctx, outs[2]); auto norm = buildATensor(ctx, outs[2]);
auto std = buildATensor(ctx, outs[3]); auto std = buildATensor(ctx, outs[3]);
auto output = buildATensor(ctx, outs[3]); auto output = buildATensor(ctx, outs[4]);
sync_bn_forward_output_cuda(input, mean, var, running_mean, running_var, sync_bn_forward_output_cuda(input, mean, var, running_mean, running_var,
weight, bias, norm, std, output, eps, momentum, weight, bias, norm, std, output, eps, momentum,
group_size); group_size);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment