Unverified Commit 6a03918f authored by Zaida Zhou, committed by GitHub

[Feature] Add support for mps (#2092)

* [Feature] Add support for MPS

* fix import error

* update ut

* fix error

* trigger CI

* use a unique basename for test file modules

* avoid bc-breaking
parent 357b484d
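For orientation, a minimal usage sketch of the interfaces this PR introduces (get_device and MPSDataParallel). It assumes an Apple silicon machine with torch >= 1.12 so that the MPS backend is available, and is illustrative only, not part of the diff.
# Sketch only: end-to-end use of the interfaces added in this PR.
import torch
import torch.nn as nn
from mmcv.device import get_device          # 'cuda' | 'mlu' | 'mps' | 'cpu'
from mmcv.device.mps import MPSDataParallel
model = nn.Conv2d(3, 8, 1)
if get_device() == 'mps':                   # Apple silicon, torch >= 1.12
    model = MPSDataParallel(model.to('mps'))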
...@@ -61,7 +61,7 @@ jobs:
          --ignore=tests/test_utils/test_parrots_jit.py \
          --ignore=tests/test_utils/test_trace.py \
          --ignore=tests/test_utils/test_hub.py \
-         --ignore=tests/test_device/test_mlu/test_mlu_parallel.py \
+         --ignore=tests/test_device \
          --ignore=tests/test_utils/test_torch_ops.py
  build_without_ops:
......
# Copyright (c) OpenMMLab. All rights reserved.
-from . import ipu, mlu
+from . import ipu, mlu, mps
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import get_device
-__all__ = ['mlu', 'ipu']
+__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union
import torch
from mmcv.utils import deprecated_api_warning
from .utils import get_device
def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
"""scatter copies tensor to devices directly."""
current_device = get_device()
if isinstance(input, list):
outputs = [scatter(_input, devices) for _input in input]
return outputs
elif isinstance(input, torch.Tensor):
output = input.contiguous()
return output.to(current_device) if devices != [-1] else output
else:
raise Exception(f'Unknown type {type(input)}.')
class Scatter:
@staticmethod
@deprecated_api_warning({'target_mlus': 'target_devices'},
cls_name='Scatter')
def forward(target_devices, input):
outputs = scatter(input, target_devices)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
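A short illustrative sketch (not part of the diff) of how this device-agnostic scatter behaves on CPU, where devices == [-1] means "stay on the host"; it mirrors the unit tests further down.
# Sketch only: CPU behaviour of the device-agnostic scatter ([-1] == stay on host).
import torch
from mmcv.device._functions import Scatter, scatter
t = torch.zeros(1, 3, 3, 3)
assert torch.allclose(scatter(t, [-1]), t)          # single tensor, returned as-is
outs = scatter([t, torch.ones(1, 4, 4, 4)], [-1])   # lists are scattered element-wise
assert isinstance(outs, list) and len(outs) == 2
wrapped = Scatter.forward([-1], t)                  # Scatter.forward always yields a tuple
assert isinstance(wrapped, tuple) and torch.allclose(wrapped[0], t)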
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MLUDataParallel
from .distributed import MLUDistributedDataParallel
-from .scatter_gather import scatter, scatter_kwargs
-__all__ = [
-    'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter',
-    'scatter_kwargs'
-]
+__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel']
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MPSDataParallel
__all__ = ['MPSDataParallel']
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel
from ..scatter_gather import scatter_kwargs
class MPSDataParallel(MMDataParallel):
"""The MPSDataParallel module that supports DataContainer.
MPSDataParallel is a class inherited from MMDataParall, which supports
MPS training and inference only.
The main differences with MMDataParallel:
- It only supports single-card of MPS, and only use first card to
run training and inference.
- It uses direct host-to-device copy instead of stream-background
scatter.
Args:
module (:class:`nn.Module`): Module to be encapsulated.
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim=0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.device_ids = [0]
self.src_device_obj = torch.device('mps:0')
def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
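A minimal sketch of wrapping a model with MPSDataParallel, assuming an MPS-capable machine (torch >= 1.12 on Apple silicon); it mirrors the docstring above and the test at the bottom of this diff, and is not part of the change itself.
# Sketch only: single-device MPS wrapping.
import torch.nn as nn
from mmcv.device.mps import MPSDataParallel
model = nn.Sequential(nn.Conv2d(2, 2, 1), nn.ReLU())
mps_model = MPSDataParallel(model.to('mps'))  # only the first (and only) MPS card is used
# mps_model.train_step(...) / mps_model.val_step(...) behave as in MMDataParallel.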
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning
from ._functions import Scatter
from .utils import get_device
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
"""Scatter inputs to target devices.
The only difference from original :func:`scatter` is to add support for
:type:`~mmcv.parallel.DataContainer`.
"""
current_device = get_device()
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if target_devices != [-1]:
obj = obj.to(current_device)
return [obj]
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_devices, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_devices, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for _ in target_devices]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None
@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_devices, dim) if inputs else []
kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
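Two behaviours documented above are easy to miss, so here is a small illustrative sketch (not part of the diff, CPU-only case with target devices [-1]): the DataContainer-aware scatter keeps cpu_only payloads on the host, and scatter_kwargs pads the shorter of inputs/kwargs with empty tuples or dicts so both end up with one entry per target device.
# Sketch only: exercises mmcv.device.scatter / scatter_kwargs on CPU ([-1]).
import torch
from mmcv.device import scatter, scatter_kwargs
from mmcv.parallel import DataContainer
img = DataContainer([torch.zeros(1, 3, 8, 8)])                # regular payload
meta = DataContainer([{'filename': 'a.jpg'}], cpu_only=True)  # stays on the host
scattered = scatter((img, meta), [-1])
# One entry per target device; the cpu_only data is returned unchanged.
assert len(scattered) == 1
inputs, kwargs = scatter_kwargs((torch.zeros(2, 3),), {}, [-1])
assert len(inputs) == len(kwargs) == 1
assert kwargs == ({},)   # the empty kwargs side is padded with an empty dict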
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def get_device() -> str:
"""Returns the currently existing device type.
Returns:
str: cuda | mlu | mps | cpu.
"""
if IS_CUDA_AVAILABLE:
return 'cuda'
elif IS_MLU_AVAILABLE:
return 'mlu'
elif IS_MPS_AVAILABLE:
return 'mps'
else:
return 'cpu'
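A small sketch of how get_device is meant to be consumed: the string it returns can be passed straight to Tensor.to(), falling back to the CPU when no accelerator is present. Illustrative only.
# Sketch only: use the detected device string directly with .to().
import torch
from mmcv.device import get_device
device = get_device()            # 'cuda' | 'mlu' | 'mps' | 'cpu'
x = torch.ones(2, 2).to(device)  # falls back to the CPU when no accelerator exists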
...@@ -14,7 +14,7 @@ class MMDataParallel(DataParallel):
    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data during both GPU and CPU inference.
-   - It implement two more APIs ``train_step()`` and ``val_step()``.
+   - It implements two more APIs ``train_step()`` and ``val_step()``.
    .. warning::
        MMDataParallel only supports single GPU training, if you need to
......
...@@ -36,7 +36,8 @@ except ImportError:
        'is_method_overridden', 'has_method'
    ]
else:
-   from .device_type import IS_IPU_AVAILABLE, IS_MLU_AVAILABLE
+   from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
+                             IS_MPS_AVAILABLE)
    from .env import collect_env
    from .hub import load_url
    from .logging import get_logger, print_log
...@@ -76,5 +77,5 @@ else:
        'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
        '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
        'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
-       'torch_meshgrid'
+       'IS_MPS_AVAILABLE', 'torch_meshgrid'
    ]
...@@ -22,3 +22,19 @@ def is_mlu_available() -> bool:
IS_MLU_AVAILABLE = is_mlu_available()
def is_mps_available() -> bool:
"""Return True if mps devices exist.
It's specialized for mac m1 chips and require torch version 1.12 or higher.
"""
try:
import torch
return hasattr(torch.backends,
'mps') and torch.backends.mps.is_available()
except Exception:
return False
IS_MPS_AVAILABLE = is_mps_available()
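A hedged example of guarding MPS-specific code paths with the new flag; IS_MPS_AVAILABLE is evaluated once at import time, exactly like IS_MLU_AVAILABLE above. Illustrative only.
# Sketch only: skip MPS-only paths when the backend is missing (e.g. Linux or torch < 1.12).
import torch
from mmcv.utils import IS_MPS_AVAILABLE
if IS_MPS_AVAILABLE:
    y = torch.zeros(4, device='mps')  # tensor allocated on the Apple GPU
else:
    y = torch.zeros(4)                # plain CPU tensor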
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device import get_device
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def test_get_device():
current_device = get_device()
if IS_CUDA_AVAILABLE:
assert current_device == 'cuda'
elif IS_MLU_AVAILABLE:
assert current_device == 'mlu'
elif IS_MPS_AVAILABLE:
assert current_device == 'mps'
else:
assert current_device == 'cpu'
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.device._functions import Scatter, scatter
from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
def test_scatter():
# if the device is CPU, just return the input
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[-1])
assert torch.allclose(input, output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[-1])
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is MLU, copy the input from CPU to MLU
if IS_MLU_AVAILABLE:
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.to('mlu'), output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output)
# if the device is MPS, copy the input from CPU to MPS
if IS_MPS_AVAILABLE:
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.to('mps'), output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mps'), output)
# input should be a tensor or list of tensor
with pytest.raises(Exception):
scatter(5, [-1])
def test_Scatter():
# if the device is CPU, just return the input
target_devices = [-1]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_devices, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input, outputs[0])
target_devices = [-1]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_devices, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is MLU, copy the input from CPU to MLU
if IS_MLU_AVAILABLE:
target_devices = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_devices, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.to('mlu'), outputs[0])
target_devices = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_devices, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output[0])
# if the device is MPS, copy the input from CPU to MPS
if IS_MPS_AVAILABLE:
target_devices = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_devices, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.to('mps'), outputs[0])
target_devices = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_devices, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mps'), output[0])
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch
-import pytest
-import torch
import torch.nn as nn
from mmcv.device.mlu import MLUDataParallel, MLUDistributedDataParallel
-from mmcv.device.mlu._functions import Scatter, scatter
from mmcv.parallel import is_module_wrapper
from mmcv.utils import IS_MLU_AVAILABLE
...@@ -38,61 +35,3 @@ def test_is_module_wrapper():
    mluddp = MLUDistributedDataParallel(model, process_group=MagicMock())
    assert is_module_wrapper(mluddp)
def test_scatter():
# if the device is CPU, just return the input
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[-1])
assert torch.allclose(input, output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[-1])
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is MLU, copy the input from CPU to MLU
if IS_MLU_AVAILABLE:
input = torch.zeros([1, 3, 3, 3])
output = scatter(input=input, devices=[0])
assert torch.allclose(input.to('mlu'), output)
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = scatter(input=inputs, devices=[0])
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output)
# input should be a tensor or list of tensor
with pytest.raises(Exception):
scatter(5, [-1])
def test_Scatter():
# if the device is CPU, just return the input
target_mlus = [-1]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_mlus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input, outputs[0])
target_mlus = [-1]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_mlus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input, output)
# if the device is MLU, copy the input from CPU to MLU
if IS_MLU_AVAILABLE:
target_mlus = [0]
input = torch.zeros([1, 3, 3, 3])
outputs = Scatter.forward(target_mlus, input)
assert isinstance(outputs, tuple)
assert torch.allclose(input.to('mlu'), outputs[0])
target_mlus = [0]
inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
outputs = Scatter.forward(target_mlus, inputs)
assert isinstance(outputs, tuple)
for input, output in zip(inputs, outputs):
assert torch.allclose(input.to('mlu'), output[0])
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch
import torch.nn as nn
from mmcv.device.mps import MPSDataParallel
from mmcv.parallel import is_module_wrapper
from mmcv.utils import IS_MPS_AVAILABLE
def mock(*args, **kwargs):
pass
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(2, 2, 1)
def forward(self, x):
return self.conv(x)
model = Model()
assert not is_module_wrapper(model)
if IS_MPS_AVAILABLE:
mpsdp = MPSDataParallel(model)
assert is_module_wrapper(mpsdp)