Unverified Commit d202cc28 authored by Hongxin Liu, committed by GitHub

[npu] change device to accelerator api (#5239)



* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
parent dd2c28a3
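The change this commit makes is mechanical across the code base: call sites that previously imported per-function helpers from `colossalai.utils.device` (or `colossalai.utils`) now go through the accelerator object returned by `get_accelerator()`. A minimal before/after sketch of the pattern, using only methods that appear in the hunks below (`get_current_device`, `synchronize`, `empty_cache`, `name`) and assuming a CUDA or NPU device is available:

```python
# Old style (removed in this commit): module-level helpers dispatched on the device type.
# from colossalai.utils.device import get_current_device, synchronize, empty_cache

# New style: one accelerator object picks the CUDA or NPU backend at runtime.
import torch

from colossalai.accelerator import get_accelerator

acc = get_accelerator()
device = acc.get_current_device()   # replaces get_current_device()
x = torch.ones(4, device=device)    # tensors land on the active accelerator
acc.synchronize()                   # replaces synchronize()
acc.empty_cache()                   # replaces empty_cache()
print(acc.name)                     # "cuda" or "npu", e.g. used by Chunk.device_type below
```

The same substitution applies to the RNG helpers (`get_rng_state`, `set_rng_state`, `manual_seed`) and the stream helpers (`Stream`, `current_stream`, `stream`), as the hunks below show.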
@@ -7,7 +7,7 @@ from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size
-from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state
+from colossalai.accelerator import get_accelerator
class SeqParallelUtils:
@@ -110,10 +110,10 @@ class Randomizer:
# 1. get the current rng state
# 2. set the seed and store the rng state
# 3. recover the original rng state
-device_original_rng_state = get_rng_state()
-manual_seed(seed)
-self.device_rng_state = get_rng_state()
-set_rng_state(device_original_rng_state)
+device_original_rng_state = get_accelerator().get_rng_state()
+get_accelerator().manual_seed(seed)
+self.device_rng_state = get_accelerator().get_rng_state()
+get_accelerator().set_rng_state(device_original_rng_state)
# to the same for cpu rng state
cpu_original_rng_state = torch.get_rng_state()
@@ -122,10 +122,10 @@ class Randomizer:
torch.set_rng_state(cpu_original_rng_state)
def _set_device_rng_state(self, rng_state):
-set_rng_state(rng_state)
+get_accelerator().set_rng_state(rng_state)
def _get_device_rng_state(self):
-current_state = get_rng_state()
+current_state = get_accelerator().get_rng_state()
return current_state
def _set_cpu_rng_state(self, rng_state):
@@ -210,7 +210,7 @@ class Randomizer:
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
-index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
+index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
@@ -232,7 +232,7 @@ class Randomizer:
if dist.is_initialized():
# convert the index to tensor
-index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
+index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
...
@@ -9,7 +9,8 @@ from typing import Any, Callable, List
import torch
import torch.multiprocessing as mp
from packaging import version
-from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
+from colossalai.accelerator import get_accelerator
def parameterize(argument: str, values: List[Any]) -> Callable:
@@ -199,7 +200,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
def _wrap_func(f):
def _execute_by_gpu_num(*args, **kwargs):
-num_avail_gpu = device_count()
+num_avail_gpu = get_accelerator().device_count()
if num_avail_gpu >= min_gpus:
f(*args, **kwargs)
@@ -263,11 +264,11 @@ def clear_cache_before_run():
def _wrap_func(f):
def _clear_cache(*args, **kwargs):
-empty_cache()
-reset_peak_memory_stats()
-reset_max_memory_allocated()
-reset_max_memory_cached()
-synchronize()
+get_accelerator().empty_cache()
+get_accelerator().reset_peak_memory_stats()
+get_accelerator().reset_max_memory_allocated()
+get_accelerator().reset_max_memory_cached()
+get_accelerator().synchronize()
gc.collect()
f(*args, **kwargs)
...
@@ -7,17 +7,12 @@ from .common import (
is_ddp_ignored,
set_seed,
)
-from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
from .multi_tensor_apply import multi_tensor_applier
from .tensor_detector import TensorDetector
from .timer import MultiTimer, Timer
__all__ = [
"conditional_context",
-"get_current_device",
-"synchronize",
-"empty_cache",
-"set_to_cuda",
"Timer",
"MultiTimer",
"multi_tensor_applier",
@@ -28,6 +23,4 @@ __all__ = [
"free_storage",
"set_seed",
"is_ddp_ignored",
-"set_device",
-"IS_NPU_AVAILABLE",
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple, Callable
import torch
import torch.distributed as dist
IS_NPU_AVAILABLE: bool = False
try:
import torch_npu # noqa
IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
pass
def set_to_cuda(models):
"""Send model to gpu.
:param models: nn.module or a list of module
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(get_current_device())
else:
return models.to(get_current_device())
def get_current_device() -> torch.device:
"""
Returns currently selected device (gpu/cpu).
If cuda available, return gpu, otherwise return cpu.
"""
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif IS_NPU_AVAILABLE:
return torch.device(f"npu:{torch.npu.current_device()}")
else:
return torch.device("cpu")
def _dispatch_device_func(fn_name: str, *args, **kwargs):
if torch.cuda.is_available():
return getattr(torch.cuda, fn_name)(*args, **kwargs)
elif IS_NPU_AVAILABLE:
return getattr(torch.npu, fn_name)(*args, **kwargs)
else:
raise RuntimeError("No device available")
# device semantics
def can_device_access_peer(device, peer_device) -> bool:
return _dispatch_device_func("can_device_access_peer", device, peer_device)
def current_device() -> int:
return _dispatch_device_func("current_device")
def current_stream(device=None):
return _dispatch_device_func("current_stream", device)
def default_stream(device=None):
return _dispatch_device_func("default_stream", device)
def device_count() -> int:
return _dispatch_device_func("device_count")
def get_device_capability(device=None) -> Tuple[int, int]:
return _dispatch_device_func("get_device_capability", device)
def get_device_name(device=None) -> str:
return _dispatch_device_func("get_device_name", device)
def get_device_properties(device):
return _dispatch_device_func("get_device_properties", device)
def set_device(index: Optional[int] = None) -> None:
if index is None:
index = dist.get_rank() % device_count()
_dispatch_device_func("set_device", index)
def set_stream(stream_):
return _dispatch_device_func("set_stream", stream_)
def stream(stream_):
return _dispatch_device_func("stream", stream_)
def synchronize():
return _dispatch_device_func("synchronize")
def utilization(device=None) -> int:
return _dispatch_device_func("utilization", device)
# random number generator
def get_rng_state(device="cuda") -> torch.Tensor:
return _dispatch_device_func("get_rng_state", device)
def get_rng_state_all() -> List[torch.Tensor]:
return _dispatch_device_func("get_rng_state_all")
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
return _dispatch_device_func("set_rng_state", new_state, device)
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
return _dispatch_device_func("set_rng_state_all", new_states)
def manual_seed(seed: int) -> None:
return _dispatch_device_func("manual_seed", seed)
def manual_seed_all(seed: int) -> None:
return _dispatch_device_func("manual_seed_all", seed)
def seed() -> None:
return _dispatch_device_func("seed")
def seed_all() -> None:
return _dispatch_device_func("seed_all")
def initial_seed() -> int:
return _dispatch_device_func("initial_seed")
# streams and events
def Stream(device=None, priority=0, **kwargs):
return _dispatch_device_func("Stream", device, priority, **kwargs)
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
# memory management
def empty_cache() -> None:
return _dispatch_device_func("empty_cache")
def memory_stats(device=None) -> Dict[str, Any]:
return _dispatch_device_func("memory_stats", device)
def memory_summary(device=None, abbreviated=False) -> str:
return _dispatch_device_func("memory_summary", device, abbreviated)
def memory_snapshot():
return _dispatch_device_func("memory_snapshot")
def memory_allocated(device=None) -> int:
return _dispatch_device_func("memory_allocated", device)
def max_memory_allocated(device=None) -> int:
return _dispatch_device_func("max_memory_allocated", device)
def reset_max_memory_allocated(device=None) -> None:
return _dispatch_device_func("reset_max_memory_allocated", device)
def reset_max_memory_cached(device=None) -> None:
return _dispatch_device_func("reset_max_memory_cached", device)
def memory_reserved(device=None) -> int:
return _dispatch_device_func("memory_reserved", device)
def max_memory_reserved(device=None) -> int:
return _dispatch_device_func("max_memory_reserved", device)
def set_per_process_memory_fraction(fraction: float, device=None) -> None:
return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
def reset_peak_memory_stats(device=None) -> None:
return _dispatch_device_func("reset_peak_memory_stats", device)
# amp
def autocast() -> Callable:
if torch.cuda.is_available():
return torch.cuda.amp.autocast()
elif IS_NPU_AVAILABLE:
return torch.npu.amp.autocast()
else:
raise RuntimeError("No device available")
@@ -3,7 +3,7 @@
import time
from typing import Tuple
-from .device import synchronize
+from colossalai.accelerator import get_accelerator
class Timer:
@@ -21,13 +21,13 @@ class Timer:
@property
def current_time(self) -> float:
-synchronize()
+get_accelerator().synchronize()
return time.time()
def start(self):
"""Firstly synchronize cuda, reset the clock and then start the timer."""
self._elapsed = 0
-synchronize()
+get_accelerator().synchronize()
self._start_time = time.time()
self._started = True
@@ -44,7 +44,7 @@ class Timer:
Returns:
int: Start-stop interval.
"""
-synchronize()
+get_accelerator().synchronize()
end_time = time.time()
elapsed = end_time - self._start_time
if keep_in_history:
...
@@ -6,8 +6,7 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
+from colossalai.accelerator import get_accelerator
class TensorState(Enum):
@@ -107,7 +106,7 @@ class Chunk:
self.valid_end = self.shard_size
self.dtype = dtype
-device = init_device or get_current_device()
+device = init_device or get_accelerator().get_current_device()
# chunk_temp is a global chunk, which only exists during building the chunks.
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
@@ -125,7 +124,7 @@ class Chunk:
# configure the init device of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
-self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
+self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device()
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
self.shard_mem = self.chunk_mem // self.pg_size
@@ -192,10 +191,7 @@ class Chunk:
if self.chunk_temp is not None:
return self.chunk_temp.device.type
else:
-if self.is_gathered or self.cuda_shard is not None:
-return "npu" if IS_NPU_AVAILABLE else "cuda"
-else:
-return "cpu"
+return get_accelerator().name
@property
def payload(self) -> torch.Tensor:
@@ -297,7 +293,7 @@ class Chunk:
self.valid_end = self.utilized_size - self.shard_begin
if self.chunk_temp.device.type == "cpu":
-self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
+self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device())
self.__update_tensors_ptr()
else:
self.cuda_global_chunk = self.chunk_temp
@@ -334,12 +330,12 @@ class Chunk:
return
if device.type == "cuda" or device.type == "npu":
-assert device == get_current_device(), "can't move chunk to another device"
+assert device == get_accelerator().get_current_device(), "can't move chunk to another device"
if self.cuda_shard:
return
-self.cuda_shard = self.cpu_shard.to(get_current_device())
+self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
if not self.pin_memory:
self.cpu_shard = None
@@ -394,7 +390,9 @@ class Chunk:
if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
else:
-self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
+self.cuda_shard = torch.empty(
+    self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
+)
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
@@ -533,7 +531,7 @@ class Chunk:
# only be called when optimizer state is in CPU memory
# the grad and param should be in the same device
assert self.cuda_shard is None
-temp = optim_chunk.cpu_shard.to(get_current_device())
+temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
# avoid to transform FP32 in CPU
self.cuda_shard = temp.to(self.dtype)
@@ -631,7 +629,7 @@ class Chunk:
grad_chunk.valid_end = self.valid_end
if grad_chunk.chunk_temp.device.type == "cpu":
-grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
+grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device())
else:
grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
grad_chunk.chunk_temp = None
...
@@ -5,7 +5,8 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
-from colossalai.utils import free_storage, get_current_device
+from colossalai.accelerator import get_accelerator
+from colossalai.utils import free_storage
from .chunk import Chunk, ChunkFullError, TensorState
@@ -20,7 +21,7 @@ class ChunkManager:
"""
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
-self.device = init_device or get_current_device()
+self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
for k, v in self.kwargs_config.items():
@@ -107,7 +108,7 @@ class ChunkManager:
return
self.__sub_memory_usage(chunk.memory_usage)
if chunk.device_type == "cpu":
-chunk.shard_move(get_current_device())
+chunk.shard_move(get_accelerator().get_current_device())
self.__add_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
@@ -276,7 +277,10 @@ class ChunkManager:
accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
else:
accumulated_grad = (
-chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
+chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device())
+.clone()
+.detach()
+.mul_(chunk.pg_size)
)
accumulated_grad_gathered = False
...
@@ -10,6 +10,7 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
+from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
@@ -27,7 +28,7 @@ from colossalai.tensor.d_tensor import (
is_distributed_tensor,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper):
# move ignored parameters to CUDA
if is_ddp_ignored(p):
-p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
+p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision)
continue
# create a fp16 parameter
@@ -815,7 +816,7 @@ class GeminiDDP(ModelWrapper):
for buffer in self.module.buffers():
if isinstance(buffer, LazyTensor):
buffer.materialize()
-buffer.data = buffer.to(get_current_device())
+buffer.data = buffer.to(get_accelerator().get_current_device())
if torch.is_floating_point(buffer):
buffer.data = buffer.to(self.mixed_precision)
...
@@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup
from torch.nn import Parameter
from torch.optim import Optimizer
+from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import OptimizerWrapper
@@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
)
-from colossalai.utils import disposable, get_current_device, is_ddp_ignored
+from colossalai.utils import disposable, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
@@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper):
grad_chunk.l2_norm = None # clear l2 norm
-comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
+comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
for group, part_norm in group_to_norm.items():
comm_buffer.fill_(part_norm)
dist.all_reduce(comm_buffer, group=group)
@@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper):
continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
-self.chunk_manager.move_chunk(chunk32, get_current_device())
+self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
# stores grad now
-self.chunk_manager.move_chunk(chunk16, get_current_device())
-self.module.set_chunk_grad_device(chunk16, get_current_device())
+self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
+self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
for group in self.param_groups:
@@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper):
state = self.optim.state[fake_param]
for k, v in state.items():
if isinstance(v, torch.Tensor):
-state[k] = v.to(get_current_device())
+state[k] = v.to(get_accelerator().get_current_device())
def _register_states_(self):
for group in self.optim.param_groups:
@@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper):
self,
param_id: int,
state_names: list,
-device: torch.device = get_current_device(),
+device: torch.device = get_accelerator().get_current_device(),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
...
from typing import Optional
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
from colossalai.zero.gemini.chunk import ChunkManager
from .memory_stats import MemStats
@@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector):
def cuda_margin_mem(self) -> float:
from colossalai.legacy.utils.memory import colo_device_memory_capacity
-return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
+return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda
@@ -5,7 +5,7 @@ from time import sleep, time
import torch
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
class MemoryMonitor:
@@ -77,7 +77,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
super().__init__()
self.keep_measuring = False
-current_device = get_current_device()
+current_device = get_accelerator().get_current_device()
def _set_cuda_device():
torch.cuda.set_device(current_device)
@@ -116,7 +116,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
while self.keep_measuring:
max_usage = max(
max_usage,
-colo_device_memory_used(get_current_device()),
+colo_device_memory_used(get_accelerator().get_current_device()),
)
sleep(self.interval)
return max_usage
...
@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
import torch
+from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
-from colossalai.utils import get_current_device
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager
@@ -85,7 +85,7 @@ class StaticPlacementPolicy(PlacementPolicy):
# init offload optim settings
# keep gathered chunks are in CUDA
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
-device = get_current_device()
+device = get_accelerator().get_current_device()
else:
device = torch.device("cpu")
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
@@ -140,7 +140,7 @@ class AutoPlacementPolicy(PlacementPolicy):
int: the volume of memory that is evicted
"""
start = time()
-cuda_capacity = colo_device_memory_capacity(get_current_device())
+cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
@@ -194,7 +194,7 @@ class AutoPlacementPolicy(PlacementPolicy):
# init offload optim settings
# keep gathered chunks are in CUDA
if chunk.keep_gathered:
-grads_device_map[p] = get_current_device()
+grads_device_map[p] = get_accelerator().get_current_device()
else:
grads_device_map[p] = torch.device("cpu")
...
@@ -6,7 +6,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
from .chunk import Chunk
@@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
if chunk.cuda_shard is not None:
shard_temp = chunk.cuda_shard
else:
-shard_temp = chunk.cpu_shard.to(get_current_device())
+shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device())
shard_temp = shard_temp.to(dtype)
-total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device())
+total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device())
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
...
@@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
-import colossalai.utils.device as device_utils
+from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin,
@@ -22,9 +22,6 @@ from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor
-# from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -183,7 +180,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# intialize communication stream for
# communication-compuation overlapping
if self._overlap_communication:
-self._comm_stream = device_utils.Stream()
+self._comm_stream = get_accelerator().Stream()
# reduction hook is only used if overlapping communication
# or stage 2 is used
@@ -217,7 +214,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return len(self._working_param_groups)
def _sanity_checks(self):
-assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
+assert get_accelerator().name in ["cuda", "npu"], "device is required"
for param_group in self.optim.param_groups:
group_params = param_group["params"]
for param in group_params:
@@ -228,7 +225,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
params_current_rank = []
-device = "cpu" if self._cpu_offload else get_current_device()
+device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()
for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
@@ -340,11 +337,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if len(moe_grad_list) > 0:
moe_flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing
-stream.wait_stream(device_utils.current_stream())
+stream.wait_stream(get_accelerator().current_stream())
else:
-stream = device_utils.current_stream()
-with device_utils.stream(stream):
+stream = get_accelerator().current_stream()
+with get_accelerator().stream(stream):
group_id = self._bucket_store.current_group_id
if self.moe_extra_dp_pg is None:
@@ -486,7 +483,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads
if self._overlap_communication:
-device_utils.synchronize()
+get_accelerator().synchronize()
self.zero_grad()
@@ -505,7 +502,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads
if self._overlap_communication:
-device_utils.synchronize()
+get_accelerator().synchronize()
self.zero_grad()
@@ -621,7 +618,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update working partition updated by the current rank
-device = get_current_device()
+device = get_accelerator().get_current_device()
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
@@ -661,7 +658,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
-total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
+total_norm_cuda = torch.tensor(
+    [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
+)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
@@ -673,7 +672,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.tensor(
-[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
+[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
)
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
@@ -765,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
Dict: the pytorch form state_dict
"""
zero_state = dict()
-device = get_current_device()
+device = get_accelerator().get_current_device()
for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
@@ -827,7 +826,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block = dict()
ret_block_size = 0
-device = get_current_device()
+device = get_accelerator().get_current_device()
local_states = self.optim.state_dict()["state"]
for param_idx, states in local_states.items():
current_block_size = 0
...
@@ -45,7 +45,6 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
```
## Define Plugin
Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously.
@@ -149,7 +148,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost(
## Training GPT-2 using hybrid parallelism
In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training.
Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training.
```python
def train_epoch(
@@ -204,4 +203,4 @@ Training the gpt-2 model
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->
\ No newline at end of file
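To make the tutorial steps above concrete, here is a rough sketch of the plugin setup and the pipeline-driven training step that the document describes. The tiny GPT-2 config, the fake batches, and the exact keyword arguments of `HybridParallelPlugin` and `execute_pipeline` are illustrative assumptions rather than excerpts from the tutorial, and the script is assumed to be launched with `torchrun` on at least two processes so that two pipeline stages exist.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

# A deliberately tiny GPT-2 so the sketch fits on small devices.
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=128))
optimizer = HybridAdam(model.parameters(), lr=1e-4)

def _criterion(outputs, inputs):
    # The HF model computes the LM loss itself when `labels` is provided.
    return outputs.loss

# Pipeline parallelism (two stages) combined with ZeRO-1, as the tutorial describes.
# The keyword names below are assumptions for illustration.
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, zero_stage=1, microbatch_size=1, precision="fp16")
booster = Booster(plugin=plugin)
model, optimizer, _criterion, _, _ = booster.boost(model, optimizer, criterion=_criterion)

# Dummy batches shaped like the tutorial's tokenized data.
def fake_batch(batch_size=2, seq_len=32):
    ids = torch.randint(0, 50257, (batch_size, seq_len))
    return {"input_ids": ids, "attention_mask": torch.ones_like(ids), "labels": ids}

train_dataloader_iter = iter([fake_batch() for _ in range(4)])

# With pipeline parallelism, one optimization step is driven by execute_pipeline
# instead of a plain forward/backward pass.
outputs = booster.execute_pipeline(train_dataloader_iter, model, _criterion, optimizer, return_loss=True)
optimizer.step()
optimizer.zero_grad()
```

In the tutorial's full training loop, only the last pipeline stage reads `outputs["loss"]`, which is why the loss logging there is guarded by a last-stage check.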
@@ -43,7 +43,6 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
```
### Define the plugin
Define a [`HybridParallelPlugin`](../basics/booster_plugins.md) object and specify the parallelism strategies to use. In this example, pipeline parallelism and ZeRO-1 are used together.
@@ -201,4 +200,4 @@ def train_epoch(
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
```
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 train_gpt_using_hybrid_parallelism.py -->
\ No newline at end of file
@@ -16,10 +16,10 @@ from utils.global_vars import get_tensorboard_writer, get_timers, set_global_var
from utils.logger import Logger
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.context import ParallelMode
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -53,7 +53,7 @@ def main():
set_global_variables(launch_time, args.tensorboard_path)
world_size = torch.distributed.get_world_size()
-get_current_device()
+get_accelerator().get_current_device()
# build model, optimizer and criterion
if args.distplan.startswith("CAI"):
@@ -67,7 +67,10 @@ def main():
# build GPT model
with ColoInitContext(
-device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg
+device=get_accelerator().get_current_device(),
+dtype=torch.half,
+default_dist_spec=default_dist_spec,
+default_pg=shard_pg,
):
config, model, numel = get_model(args, logger)
@@ -78,7 +81,7 @@ def main():
elif args.distplan == "CAI_Gemini":
gemini_config = dict(
strict_ddp_mode=args.tp_degree == 1,
-device=get_current_device(),
+device=get_accelerator().get_current_device(),
placement_policy=args.placement,
pin_memory=True,
hidden_dim=model.config.hidden_size,
...
@@ -20,11 +20,11 @@ from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
disable_existing_loggers()
logger = get_dist_logger()
@@ -386,7 +386,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
-torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
+torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
@@ -401,7 +401,7 @@ def main(args):
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-pipeline.to(get_current_device())
+pipeline.to(get_accelerator().get_current_device())
for example in tqdm(
sample_dataloader,
@@ -578,8 +578,8 @@ def main(args):
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
-vae.to(get_current_device(), dtype=weight_dtype)
-text_encoder.to(get_current_device(), dtype=weight_dtype)
+vae.to(get_accelerator().get_current_device(), dtype=weight_dtype)
+text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
@@ -613,7 +613,7 @@ def main(args):
torch.cuda.reset_peak_memory_stats()
# Move batch to gpu
for key, value in batch.items():
-batch[key] = value.to(get_current_device(), non_blocking=True)
+batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True)
# Convert images to latent space
optimizer.zero_grad()
...
@@ -21,13 +21,13 @@ from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
disable_existing_loggers()
logger = get_dist_logger()
@@ -385,7 +385,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
-torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
+torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
@@ -400,7 +400,7 @@ def main(args):
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-pipeline.to(get_current_device())
+pipeline.to(get_accelerator().get_current_device())
for example in tqdm(
sample_dataloader,
@@ -598,8 +598,8 @@ def main(args):
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
-vae.to(get_current_device(), dtype=weight_dtype)
-text_encoder.to(get_current_device(), dtype=weight_dtype)
+vae.to(get_accelerator().get_current_device(), dtype=weight_dtype)
+text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
@@ -633,7 +633,7 @@ def main(args):
torch.cuda.reset_peak_memory_stats()
# Move batch to gpu
for key, value in batch.items():
-batch[key] = value.to(get_current_device(), non_blocking=True)
+batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True)
# Convert images to latent space
optimizer.zero_grad()
...
@@ -13,12 +13,12 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
# ==============================
# Prepare Hyperparameters
@@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
model.eval()
-correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
-total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
+total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
for images, labels in test_dataloader:
images = images.cuda()
labels = labels.cuda()
...