Unverified Commit 6f6b17e6 authored by mattcasey02's avatar mattcasey02 Committed by GitHub
Browse files

[Fix] Fixed version comparison to include prerelease versions (#1877)

* Fixed version comparison to include prerelease versions

Currently all tagged versions of torch 1.11.0 have version 1.11.0a0. Previously the comparison to 1.11.0 failed and self._sync_params() was still used, causing an error. This fix should include all versions of 1.11.

* Same update

Didn't realize that 1.11.0 was mentioned multiple times in the file. This fixes the other instances.
parent 5221a388
...@@ -45,7 +45,7 @@ class MMDistributedDataParallel(DistributedDataParallel): ...@@ -45,7 +45,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
logger='mmcv') logger='mmcv')
if ('parrots' not in TORCH_VERSION if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd(): if self._check_sync_bufs_pre_fwd():
self._sync_buffers() self._sync_buffers()
else: else:
...@@ -65,7 +65,7 @@ class MMDistributedDataParallel(DistributedDataParallel): ...@@ -65,7 +65,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
output = self.module.train_step(*inputs, **kwargs) output = self.module.train_step(*inputs, **kwargs)
if ('parrots' not in TORCH_VERSION if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd(): if self._check_sync_bufs_post_fwd():
self._sync_buffers() self._sync_buffers()
...@@ -100,7 +100,7 @@ class MMDistributedDataParallel(DistributedDataParallel): ...@@ -100,7 +100,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
logger='mmcv') logger='mmcv')
if ('parrots' not in TORCH_VERSION if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd(): if self._check_sync_bufs_pre_fwd():
self._sync_buffers() self._sync_buffers()
else: else:
...@@ -120,7 +120,7 @@ class MMDistributedDataParallel(DistributedDataParallel): ...@@ -120,7 +120,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
output = self.module.val_step(*inputs, **kwargs) output = self.module.val_step(*inputs, **kwargs)
if ('parrots' not in TORCH_VERSION if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0')): and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd(): if self._check_sync_bufs_post_fwd():
self._sync_buffers() self._sync_buffers()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment