"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1ddf3f3a19095344166ad7207ebc5be7a862d17e"
Commit a5db5f66 authored by wangjiangben-hw, committed by GitHub

[Feature] Support training on NPU device (#2262)



* init npu

* add npu extension and focal loss adapter

* clean code

* clean code

* clean code

* clean code

* fix autocast bugs on npu (#2273)

* code format

* code format

* code format

* bug fix

* pytorch_npu_helper.hpp clean code

* Npu dev (#2306)

* fix autocast bugs on npu
* using scatter_kwargs in mmcv.device.scatter_gather

* raise ImportError when compile with npu

* add npu test case (#2307)

* add npu test case

* Update focal_loss.py

* add comment

* clean lint

* update dtype assert

* update DDP forward and comment

* fix bug
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: ckirchhoff <515629648@qq.com>
parent 92504176
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import ipu, mlu, mps
+from . import ipu, mlu, mps, npu
 from .scatter_gather import scatter, scatter_kwargs
 from .utils import get_device

-__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
+__all__ = [
+    'npu', 'mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs'
+]
# Copyright Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import NPUDataParallel
from .distributed import NPUDistributedDataParallel
__all__ = ['NPUDataParallel', 'NPUDistributedDataParallel']
# Copyright Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import sys

import torch

from mmcv.device.scatter_gather import scatter_kwargs
from mmcv.parallel import MMDataParallel


def _check_balance(*args, **kwargs):
    return


# Since we do not have a similar hardware unit multi_processor on the NPU,
# the corresponding devices_properties attribute does not exist and cannot
# be checked. So we mask the _check_balance function in DataParallel to
# make initialization pass.
for m in sys.modules:
    if m.startswith('torch') or 'mmcv' in m:
        if hasattr(sys.modules[m], '_check_balance'):
            setattr(sys.modules[m], '_check_balance', _check_balance)


class NPUDataParallel(MMDataParallel):
    """The NPUDataParallel module that supports DataContainer.

    NPUDataParallel is a class inherited from MMDataParallel, which
    supports NPU training and inference only.

    The main differences with MMDataParallel:

    - It supports only a single NPU card, and only the first card is
      used to run training and inference.
    - It uses a direct host-to-device copy instead of stream-background
      scatter.

    .. warning::
        NPUDataParallel supports only single-NPU training. If you need to
        train with multiple NPUs, please use NPUDistributedDataParallel
        instead. If you have multiple NPUs, you can set the device_ids
        argument to specify which card to run on.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        device_id = kwargs.get('device_ids', [0])[0]
        self.device_ids = [device_id]
        self.src_device_obj = torch.device(f'npu:{device_id}')
        torch.npu.set_device(self.src_device_obj)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
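
For orientation, a minimal single-NPU training sketch built on the class above. The toy model, optimizer and random data are hypothetical placeholders; it assumes an Ascend host where torch_npu is installed and registers the 'npu' device with PyTorch.

# Hedged sketch: single-NPU training step with NPUDataParallel.
import torch
import torch.nn as nn
import torch_npu  # noqa: F401

from mmcv.device.npu import NPUDataParallel

# The module itself must already live on npu:0; scatter only moves the inputs.
model = NPUDataParallel(nn.Linear(8, 2).to('npu:0'), device_ids=[0])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(4, 8)                       # CPU tensor; scatter copies it to npu:0
y = torch.randint(0, 2, (4, )).to('npu:0')  # placeholder labels
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()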
# Copyright Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device.scatter_gather import scatter_kwargs
from mmcv.parallel import MMDistributedDataParallel


class NPUDistributedDataParallel(MMDistributedDataParallel):
    """The DDP module that supports DataContainer.

    NPUDDP has one difference from MMDDP: it moves data to the NPU by
    copying instead of scattering.
    """

    def to_kwargs(self, inputs, kwargs, device_id):
        # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
        # to move all tensors to device_id
        return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        # The scatter method is not supported on the NPU, so the scatter
        # branch inside the parent DDP forward would be skipped and the
        # inputs would never be placed on the device. forward is therefore
        # overridden here to scatter the inputs explicitly and avoid this
        # device misalignment.
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            return super().forward(*inputs[0], **kwargs[0])
        return super().forward(*inputs, **kwargs)
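
A hedged multi-NPU sketch using the DDP wrapper above. It assumes one process per NPU on a single node (so RANK doubles as the local device index) and a launcher that exports the usual RANK/WORLD_SIZE/MASTER_ADDR variables; the model is a placeholder.

# Hedged sketch: per-process setup for NPU distributed training.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu  # noqa: F401

from mmcv.device.npu import NPUDistributedDataParallel

rank = int(os.environ['RANK'])
torch.npu.set_device(rank)
dist.init_process_group(backend='hccl')   # HCCL backend comes from torch_npu

model = nn.Linear(8, 2).to(f'npu:{rank}')
ddp_model = NPUDistributedDataParallel(model, device_ids=[rank])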
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE,
+                        IS_NPU_AVAILABLE)


 def get_device() -> str:
     """Returns the currently existing device type.

+    .. note::
+        Since npu provides tools to automatically convert cuda functions,
+        we need to make judgments on npu first to avoid entering
+        the cuda branch when using npu.
+
     Returns:
         str: cuda | mlu | mps | cpu.
     """
-    if IS_CUDA_AVAILABLE:
+    if IS_NPU_AVAILABLE:
+        return 'npu'
+    elif IS_CUDA_AVAILABLE:
         return 'cuda'
     elif IS_MLU_AVAILABLE:
         return 'mlu'
 ...
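
With this change, device-agnostic user code keeps working unchanged; a small sketch (the tensor shape is arbitrary, and the 'npu' string only resolves on hosts where torch_npu is installed):

# Hedged sketch: allocate on whatever device the host provides.
import torch

from mmcv.device import get_device

device = get_device()                 # 'npu' | 'cuda' | 'mlu' | 'mps' | 'cpu'
x = torch.zeros(2, 3, device=device)  # lands on npu:0 on an Ascend host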
/******************************************************************************
* Copyright (c) 2022 Huawei Technologies Co., Ltd
* All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
#ifndef PYTORCH_NPU_HELPER_HPP_
#define PYTORCH_NPU_HELPER_HPP_
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#define NPU_NAME_SPACE at_npu::native
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#define CHECK_NPU(x) \
TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
#endif // PYTORCH_NPU_HELPER_HPP_
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
at::Tensor target_y = at::reshape(target, input.sizes());
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
OpCommand cmd;
cmd.Name("SigmoidFocalLoss")
.Input(input)
.Input(target_y)
.Input(weight_y)
.Output(output)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Run();
}
void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha);
void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma,
float alpha) {
at::Tensor target_y = at::reshape(target, input.sizes());
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
at::Tensor grad_up = at::ones_like(input);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
OpCommand cmd;
cmd.Name("SigmoidFocalLossGrad")
.Input(input)
.Input(target_y)
.Input(grad_up)
.Input(weight_y)
.Output(grad_input)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Run();
}
void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor grad_input,
float gamma, float alpha);
void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
int64_t n_class = input.size(1);
at::Tensor target_y =
at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
OpCommand cmd;
cmd.Name("SoftmaxFocalLoss")
.Input(input)
.Input(target_y)
.Input(weight_y)
.Output(output)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Run();
}
void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma,
float alpha);
void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input,
float gamma, float alpha) {
int64_t n_class = input.size(1);
at::Tensor target_y =
at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
at::Tensor grad_up = at::ones_like(input);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
if (weight_size > 0) {
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
OpCommand cmd;
cmd.Name("SoftmaxFocalLossGrad")
.Input(input)
.Input(target_y)
.Input(grad_up)
.Input(weight_y)
.Output(grad_input)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Run();
}
void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha);
REGISTER_NPU_IMPL(sigmoid_focal_loss_forward_impl,
sigmoid_focal_loss_forward_npu);
REGISTER_NPU_IMPL(sigmoid_focal_loss_backward_impl,
sigmoid_focal_loss_backward_npu);
REGISTER_NPU_IMPL(softmax_focal_loss_forward_impl,
softmax_focal_loss_forward_npu);
REGISTER_NPU_IMPL(softmax_focal_loss_backward_impl,
softmax_focal_loss_backward_npu);
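
At the Python level these kernels are reached through the existing mmcv.ops focal-loss wrappers; a hedged usage sketch follows (it assumes an mmcv build with FORCE_NPU=1 and a working torch_npu; shapes and hyper-parameters are illustrative).

# Hedged sketch: calling the focal loss ops on NPU tensors.
import torch
import torch_npu  # noqa: F401

from mmcv.ops import sigmoid_focal_loss, softmax_focal_loss

logits = torch.randn(4, 10).to('npu:0').requires_grad_()   # (N, num_classes)
labels = torch.randint(0, 10, (4, )).long().to('npu:0')    # int64 class indices

# Positional arguments mirror the Function signature:
# (input, target, gamma, alpha, weight, reduction)
loss = sigmoid_focal_loss(logits, labels, 2.0, 0.25, None, 'mean')
loss.backward()   # backward dispatches to SigmoidFocalLossGrad above

# softmax_focal_loss is called the same way and hits SoftmaxFocalLoss(Grad).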
@@ -38,8 +38,7 @@ class SigmoidFocalLossFunction(Function):
                 weight: Optional[torch.Tensor] = None,
                 reduction: str = 'mean') -> torch.Tensor:

-        assert isinstance(
-            target, (torch.Tensor, torch.LongTensor, torch.cuda.LongTensor))
+        assert target.dtype == torch.long
         assert input.dim() == 2
         assert target.dim() == 1
         assert input.size(0) == target.size(0)
@@ -143,7 +142,7 @@ class SoftmaxFocalLossFunction(Function):
                 weight: Optional[torch.Tensor] = None,
                 reduction='mean') -> torch.Tensor:

-        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+        assert target.dtype == torch.long
         assert input.dim() == 2
         assert target.dim() == 1
         assert input.size(0) == target.size(0)
 ...
@@ -13,7 +13,7 @@ from torch import distributed as dist
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)

-from mmcv.utils import IS_MLU_AVAILABLE
+from mmcv.utils import IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


 def _find_free_port() -> str:
@@ -58,6 +58,14 @@ def _init_dist_pytorch(backend: str, **kwargs) -> None:
             rank=rank,
             world_size=int(os.environ['WORLD_SIZE']),
             **kwargs)
+    elif IS_NPU_AVAILABLE:
+        import torch_npu  # noqa: F401
+        torch.npu.set_device(rank)
+        dist.init_process_group(
+            backend='hccl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
     else:
         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)
 ...
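
In user code nothing changes; a hedged sketch of how the new branch is reached (it assumes the job is started by a standard PyTorch launcher that exports RANK, WORLD_SIZE and MASTER_ADDR/MASTER_PORT):

# Hedged sketch: mmcv picks the HCCL backend automatically on NPU machines.
from mmcv.runner import init_dist

init_dist('pytorch')   # -> _init_dist_pytorch, which selects 'hccl' when IS_NPU_AVAILABLE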
@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 from torch.nn.parameter import Parameter

-from mmcv.utils import TORCH_VERSION, digit_version
+from mmcv.utils import IS_NPU_AVAILABLE, TORCH_VERSION, digit_version
 from .dist_utils import allreduce_grads as _allreduce_grads

 try:
@@ -18,7 +18,10 @@ try:
     # and used; otherwise, auto fp16 will adopt mmcv's implementation.
     # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
     # manually, so the behavior may not be consistent with real amp.
-    from torch.cuda.amp import autocast
+    if IS_NPU_AVAILABLE:
+        from torch.npu.amp import autocast
+    else:
+        from torch.cuda.amp import autocast
 except ImportError:
     pass
 ...
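
Code decorated with mmcv's auto_fp16 therefore needs no change to pick up NPU autocast; a hedged sketch with a hypothetical toy module:

# Hedged sketch: on an Ascend host the import switch above makes auto_fp16
# use torch.npu.amp.autocast under the hood.
import torch.nn as nn

from mmcv.runner import auto_fp16


class ToyHead(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)
        self.fp16_enabled = False  # flipped to True by wrap_fp16_model / the runner

    @auto_fp16(apply_to=('x', ))
    def forward(self, x):
        return self.fc(x)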
@@ -9,7 +9,8 @@ import torch.nn as nn
 from torch import Tensor
 from torch.nn.utils import clip_grad

-from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
+from mmcv.utils import (IS_NPU_AVAILABLE, TORCH_VERSION, _BatchNorm,
+                        digit_version)
 from ..dist_utils import allreduce_grads
 from ..fp16_utils import LossScaler, wrap_fp16_model
 from .hook import HOOKS, Hook
@@ -17,7 +18,10 @@ from .hook import HOOKS, Hook
 try:
     # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
     # and used; otherwise, auto fp16 will adopt mmcv's implementation.
-    from torch.cuda.amp import GradScaler
+    if IS_NPU_AVAILABLE:
+        from torch.npu.amp import GradScaler
+    else:
+        from torch.cuda.amp import GradScaler
 except ImportError:
     pass
 ...
@@ -37,7 +37,7 @@ except ImportError:
     ]
 else:
     from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
-                              IS_MPS_AVAILABLE)
+                              IS_MPS_AVAILABLE, IS_NPU_AVAILABLE)
     from .env import collect_env
     from .hub import load_url
     from .logging import get_logger, print_log
@@ -77,5 +77,5 @@ else:
         'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
         '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
         'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
-        'IS_MPS_AVAILABLE', 'torch_meshgrid'
+        'IS_MPS_AVAILABLE', 'IS_NPU_AVAILABLE', 'torch_meshgrid'
     ]
@@ -38,3 +38,16 @@ def is_mps_available() -> bool:


 IS_MPS_AVAILABLE = is_mps_available()
+
+
+def is_npu_available() -> bool:
+    """Return True if npu devices exist."""
+    try:
+        import torch
+        import torch_npu
+        return (hasattr(torch, 'npu') and torch_npu.npu.is_available())
+    except Exception:
+        return False
+
+
+IS_NPU_AVAILABLE = is_npu_available()
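
A small hedged sketch of guarding NPU-only code paths with the new flag (the tensor shape and the fallback branch are illustrative):

# Hedged sketch: feature-gating on IS_NPU_AVAILABLE.
import torch

from mmcv.utils import IS_NPU_AVAILABLE

if IS_NPU_AVAILABLE:
    # is_npu_available() has already imported torch_npu, so the 'npu'
    # device string is registered with torch at this point.
    x = torch.zeros(2, 3, device='npu')
else:
    x = torch.zeros(2, 3)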
@@ -330,6 +330,21 @@ def get_extensions():
             extension = CppExtension
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mps'))
+        elif (os.getenv('FORCE_NPU', '0') == '1'):
+            print(f'Compiling {ext_name} only with CPU and NPU')
+            try:
+                from torch_npu.utils.cpp_extension import NpuExtension
+                define_macros += [('MMCV_WITH_NPU', None)]
+                extension = NpuExtension
+            except Exception:
+                raise ImportError('can not find any torch_npu')
+            # src
+            op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
+                glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
+                glob.glob('./mmcv/ops/csrc/common/npu/*.cpp') + \
+                glob.glob('./mmcv/ops/csrc/pytorch/npu/*.cpp')
+            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
+            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/npu'))
         else:
             print(f'Compiling {ext_name} only with CPU')
             op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
 ...
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.device import get_device
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
+from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE,
+                        IS_NPU_AVAILABLE)


 def test_get_device():
     current_device = get_device()
-    if IS_CUDA_AVAILABLE:
+    if IS_NPU_AVAILABLE:
+        assert current_device == 'npu'
+    elif IS_CUDA_AVAILABLE:
         assert current_device == 'cuda'
     elif IS_MLU_AVAILABLE:
         assert current_device == 'mlu'
 ...
@@ -3,7 +3,7 @@ import pytest
 import torch

 from mmcv.device._functions import Scatter, scatter
-from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
+from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE, IS_NPU_AVAILABLE


 def test_scatter():
@@ -28,6 +28,17 @@ def test_scatter():
         for input, output in zip(inputs, outputs):
             assert torch.allclose(input.to('mlu'), output)

+    # if the device is NPU, copy the input from CPU to NPU
+    if IS_NPU_AVAILABLE:
+        input = torch.zeros([1, 3, 3, 3])
+        output = scatter(input=input, devices=[0])
+        assert torch.allclose(input.to('npu'), output)
+
+        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
+        outputs = scatter(input=inputs, devices=[0])
+        for input, output in zip(inputs, outputs):
+            assert torch.allclose(input.to('npu'), output)
+
     # if the device is MPS, copy the input from CPU to MPS
     if IS_MPS_AVAILABLE:
         input = torch.zeros([1, 3, 3, 3])
 ...
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import torch.nn as nn

from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel
from mmcv.parallel import is_module_wrapper
from mmcv.utils import IS_NPU_AVAILABLE


def mock(*args, **kwargs):
    pass


@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    if IS_NPU_AVAILABLE:
        npudp = NPUDataParallel(model)
        assert is_module_wrapper(npudp)

        npuddp = NPUDistributedDataParallel(model, process_group=MagicMock())
        assert is_module_wrapper(npuddp)
@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch

-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

 _USING_PARROTS = True
 try:
@@ -130,6 +130,10 @@ class Testfocalloss:
         self._test_softmax(dtype=torch.half)

     @pytest.mark.parametrize('device', [
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support')),
         pytest.param(
             'cuda',
             marks=pytest.mark.skipif(
@@ -143,6 +147,10 @@ class Testfocalloss:
         self._test_sigmoid(device=device, dtype=torch.float)

     @pytest.mark.parametrize('device', [
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support')),
         pytest.param(
             'cuda',
             marks=pytest.mark.skipif(
 ...