# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import pytest
from unit.simple_model import SimpleModel
from deepspeed import OnDevice
from packaging import version as pkg_version
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest


@pytest.mark.parametrize('device', ['meta', get_accelerator().device_name(0)])
class TestOnDevice(DistributedTest):
    """Check that ``OnDevice`` controls the device and dtype of newly built parameters.

    Parametrized over the ``meta`` device and the accelerator's device 0.
    """
    # Presumably the number of ranks DistributedTest launches for this test.
    world_size = 1

    def test_on_device(self, device):
        """Build ``SimpleModel`` inside ``OnDevice`` and verify every parameter.

        Each parameter must end up on ``device`` with dtype ``torch.half``.
        The ``meta`` case is skipped on torch versions older than 1.10.
        """
        if device == "meta" and pkg_version.parse(torch.__version__) < pkg_version.parse("1.10"):
            pytest.skip("meta tensors only became stable after torch 1.10")

        with OnDevice(dtype=torch.half, device=device):
            model = SimpleModel(4)

        params = list(model.parameters())
        # Guard against a vacuous pass: the checks below assert nothing if the
        # model exposes no parameters.
        assert params, "SimpleModel(4) should have parameters to check"

        expected_device = torch.device(device)
        for p in params:
            assert p.device == expected_device
            assert p.dtype == torch.half