"git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "1a665a63b09a83ab06317f8acfe7e7f75037c5ab"
Commit bc76b018 authored by Jie's avatar Jie
Browse files

update SyncBatchNorm API for primitive implementation to support apex.parallel.convert_syncbn_model

parent 8e8dd35d
...@@ -55,7 +55,7 @@ class SyncBatchNorm(_BatchNorm): ...@@ -55,7 +55,7 @@ class SyncBatchNorm(_BatchNorm):
>>> inp = torch.randn(10, 14, 14, 100).cuda() >>> inp = torch.randn(10, 14, 14, 100).cuda()
""" """
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last = False): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group self.process_group = process_group
self.channel_last = channel_last self.channel_last = channel_last
......
...@@ -48,7 +48,9 @@ class SyncBatchNorm(_BatchNorm): ...@@ -48,7 +48,9 @@ class SyncBatchNorm(_BatchNorm):
warned = False warned = False
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
if channel_last == True:
raise AttributeError("channel_last is not supported by primitive SyncBatchNorm implementation. Try install apex with `--cuda_ext` if channel_last is desired.")
if not SyncBatchNorm.warned: if not SyncBatchNorm.warned:
print("Warning: using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext. The exception raised when attempting to import the cuda backend was: ", self.syncbn_import_error) print("Warning: using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext. The exception raised when attempting to import the cuda backend was: ", self.syncbn_import_error)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment