real_accelerator.py 3.58 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
aiss's avatar
aiss committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

# DeepSpeedAccelerator is reachable under two module paths: the build-time
# path (accelerator.abstract_accelerator) and the run-time path
# (deepspeed.accelerator.abstract_accelerator). Each alias starts as None
# and is only rebound when the corresponding import succeeds.
dsa1 = None
try:
    from accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa1
except ImportError:
    # Build-time path unavailable; dsa1 stays None.
    pass

dsa2 = None
try:
    from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator as dsa2
except ImportError:
    # Run-time path unavailable; dsa2 stays None.
    pass

# Module-wide accelerator instance; None until get_accelerator() or
# set_accelerator() installs one.
ds_accelerator = None


def _validate_accelerator(accel_obj):
    """Ensure ``accel_obj`` is an instance of a known DeepSpeedAccelerator base.

    Args:
        accel_obj: Candidate accelerator object to check.

    Raises:
        AssertionError: If ``accel_obj`` is not an instance of either
            importable DeepSpeedAccelerator base class (this includes the
            case where neither base class could be imported).
    """
    # abstract_accelerator has a different path during build time
    # (accelerator.abstract_accelerator) and run time
    # (deepspeed.accelerator.abstract_accelerator), and an extension would
    # import the run-time DeepSpeedAccelerator as its base class. So
    # accel_obj is compared against both base classes: being an instance of
    # either one is enough to be considered a conforming object.
    known_bases = tuple(base for base in (dsa1, dsa2) if base is not None)

    # isinstance() against an empty tuple is False, so the check also fails
    # when neither base class was importable.
    if not isinstance(accel_obj, known_bases):
        raise AssertionError(f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator')

    # TODO: turn off is_available test since this breaks tests
    #assert accel_obj.is_available(), \
    #    f'{accel_obj.__class__.__name__} accelerator fails is_available() test'


def get_accelerator():
    """Return the module-wide accelerator, creating it on first use.

    If no accelerator has been installed yet, an ``XPU_Accelerator`` from
    ``intel_extension_for_deepspeed`` is used when that package can be
    imported; otherwise a ``CUDA_Accelerator`` is created. The new object is
    validated before it is cached, so a failed validation never leaves a
    non-conforming accelerator behind for later calls.

    Returns:
        The cached accelerator object.

    Raises:
        AssertionError: If the newly created accelerator fails
            ``_validate_accelerator``.
    """
    global ds_accelerator
    if ds_accelerator is not None:
        return ds_accelerator

    try:
        from intel_extension_for_deepspeed import XPU_Accelerator as accel_cls
    except ImportError:
        # Extension not installed; fall back to CUDA below.
        accel_cls = None

    if accel_cls is None:
        from .cuda_accelerator import CUDA_Accelerator as accel_cls

    candidate = accel_cls()
    # Validate first, assign second (same order as set_accelerator) so that
    # the global is only ever bound to a conforming object.
    _validate_accelerator(candidate)
    ds_accelerator = candidate
    return ds_accelerator


def set_accelerator(accel_obj):
    """Install ``accel_obj`` as the module-wide accelerator.

    The object is validated first; the global is only rebound when
    validation succeeds, after which ``get_accelerator()`` returns it.

    Args:
        accel_obj: Accelerator object to install.

    Raises:
        AssertionError: Propagated from ``_validate_accelerator`` when
            ``accel_obj`` is not a DeepSpeedAccelerator instance.
    """
    global ds_accelerator
    # Validate before assigning so a rejected object never replaces the
    # current accelerator.
    _validate_accelerator(accel_obj)
    ds_accelerator = accel_obj


'''
-----------[code] test_get.py -----------
from deepspeed.accelerator import get_accelerator
my_accelerator = get_accelerator()
print(f'{my_accelerator._name=}')
print(f'{my_accelerator._communication_backend=}')
print(f'{my_accelerator.HalfTensor().device=}')
print(f'{my_accelerator.total_memory()=}')
-----------[code] test_get.py -----------

---[output] python test_get.py---------
my_accelerator._name='cuda'
my_accelerator._communication_backend='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_get.py---------

**************************************************************************
-----------[code] test_set.py -----------
from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
cu_accel = CUDA_Accelerator()
print(f'{id(cu_accel)=}')
from deepspeed.accelerator import set_accelerator, get_accelerator
set_accelerator(cu_accel)

my_accelerator = get_accelerator()
print(f'{id(my_accelerator)=}')
print(f'{my_accelerator._name=}')
print(f'{my_accelerator._communication_backend=}')
print(f'{my_accelerator.HalfTensor().device=}')
print(f'{my_accelerator.total_memory()=}')
-----------[code] test_set.py -----------


---[output] python test_set.py---------
id(cu_accel)=139648165478304
id(my_accelerator)=139648165478304
my_accelerator._name='cuda'
my_accelerator._communication_backend='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_set.py---------
'''