"...text-generation-inference.git" did not exist on "85aa7e2e7b02608eea04206b6cc0fa0ccced80ef"
Unverified Commit ddb52228 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #28 from open-mmlab/pytorch-1.0

Add compatibility with PyTorch 0.4.1 and 1.0
parents fd009613 477010fb
...@@ -54,7 +54,7 @@ def imfrombytes(content, flag='color'): ...@@ -54,7 +54,7 @@ def imfrombytes(content, flag='color'):
Returns: Returns:
ndarray: Loaded image array. ndarray: Loaded image array.
""" """
img_np = np.fromstring(content, np.uint8) img_np = np.frombuffer(content, np.uint8)
flag = imread_flags[flag] if is_str(flag) else flag flag = imread_flags[flag] if is_str(flag) else flag
img = cv2.imdecode(img_np, flag) img = cv2.imdecode(img_np, flag)
return img return img
......
...@@ -33,7 +33,10 @@ class MMDistributedDataParallel(nn.Module): ...@@ -33,7 +33,10 @@ class MMDistributedDataParallel(nn.Module):
self._dist_broadcast_coalesced(module_states, self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size) self.broadcast_bucket_size)
if self.broadcast_buffers: if self.broadcast_buffers:
buffers = [b.data for b in self.module._all_buffers()] if torch.__version__ < '1.0':
buffers = [b.data for b in self.module._all_buffers()]
else:
buffers = [b.data for b in self.module.buffers()]
if len(buffers) > 0: if len(buffers) > 0:
self._dist_broadcast_coalesced(buffers, self._dist_broadcast_coalesced(buffers,
self.broadcast_bucket_size) self.broadcast_bucket_size)
......
...@@ -5,6 +5,7 @@ from getpass import getuser ...@@ -5,6 +5,7 @@ from getpass import getuser
from socket import gethostname from socket import gethostname
import mmcv import mmcv
import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -13,7 +14,11 @@ def get_host_info(): ...@@ -13,7 +14,11 @@ def get_host_info():
def get_dist_info(): def get_dist_info():
if dist._initialized: if torch.__version__ < '1.0':
initialized = dist._initialized
else:
initialized = dist.is_initialized()
if initialized:
rank = dist.get_rank() rank = dist.get_rank()
world_size = dist.get_world_size() world_size = dist.get_world_size()
else: else:
......
__version__ = '0.2.2' __version__ = '0.2.3'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment