Unverified commit d202cc28, authored by Hongxin Liu, committed by GitHub

[npu] change device to accelerator api (#5239)



* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
parent dd2c28a3
...
@@ -47,7 +47,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
       - name: 📚 Checkout
        uses: actions/checkout@v3
...
...
@@ -79,7 +79,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/examples-data:/data/
-    timeout-minutes: 10
+    timeout-minutes: 15
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
       cancel-in-progress: true
...
...
@@ -35,7 +35,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
       - name: 📚 Checkout
        uses: actions/checkout@v3
...
...
@@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator

 from .base import OnPolicyTrainer
 from .callbacks import Callback
@@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer):
         self.critic_optim = critic_optim

         self.offload_inference_models = offload_inference_models
-        self.device = get_current_device()
+        self.device = get_accelerator().get_current_device()

     def _before_fit(
         self,
...
...
@@ -6,7 +6,6 @@ import torch.nn as nn
 import colossalai
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.gemini_ddp import GeminiDDP

 from .ddp import DDPStrategy
@@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy):
             warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")

+        # colossalai changed the get_current_device API in version 0.3.4; support both
+        try:
+            from colossalai.accelerator import get_accelerator
+
+            chunk_init_device = get_accelerator().get_current_device()
+        except ImportError:
+            from colossalai.utils import get_current_device
+
+            chunk_init_device = get_current_device()
+
         # NOTE: dist should be initialized before calling get_current_device()
         plugin_initializer = lambda: GeminiPlugin(
-            chunk_init_device=get_current_device(),
+            chunk_init_device=chunk_init_device,
             placement_policy=placement_policy,
             shard_param_frac=shard_param_frac,
             offload_optim_frac=offload_optim_frac,
...
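Note: the try/except shim added above recurs in several application scripts in this PR; it can be factored into a small helper so call sites stay clean. A minimal sketch (the helper name is ours, not part of either API):

```python
import torch


def resolve_current_device() -> torch.device:
    """Work on colossalai >= 0.3.4 (accelerator API) and older releases alike."""
    try:
        from colossalai.accelerator import get_accelerator  # new API

        return get_accelerator().get_current_device()
    except ImportError:
        from colossalai.utils import get_current_device  # pre-0.3.4 fallback

        return get_current_device()
```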
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """
 Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
 """

-import json
 import argparse
+import json
 import os
 import resource
 from contextlib import nullcontext
-from tqdm import tqdm

 import torch
 import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+    DataCollatorForSupervisedDataset,
+    StatefulDistributedSampler,
+    load_tokenized_dataset,
+    setup_distributed_dataloader,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import (
-    GeminiPlugin,
-    LowLevelZeroPlugin,
-    HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.lazy import LazyInitContext
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
-from colossal_llama2.dataset.loader import (
-    load_tokenized_dataset,
-    setup_distributed_dataloader,
-    DataCollatorForSupervisedDataset,
-    StatefulDistributedSampler,
-)
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters


 def get_model_numel(model: torch.nn.Module) -> int:
...
@@ -215,9 +208,18 @@ def main() -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
-    init_ctx = (
-        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
-    )
+    # colossalai changed the get_current_device API in version 0.3.4; support both
+    try:
+        from colossalai.accelerator import get_accelerator
+
+        current_device = get_accelerator().get_current_device()
+    except ImportError:
+        from colossalai.utils import get_current_device
+
+        current_device = get_current_device()
+
+    init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+
     with init_ctx:
         model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
         # Freeze part of parameters.
@@ -320,7 +322,7 @@ def main() -> None:
                 initial=start_step,
             ) as pbar:
                 for step, batch in pbar:
-                    batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+                    batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
                     batch_output = model(**batch)
@@ -372,9 +374,7 @@ def main() -> None:
     # Final save.
     coordinator.print_on_master("Start saving final model checkpoint")
     booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
-    coordinator.print_on_master(
-        f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
-    )
+    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")

     coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
...
 from .api import auto_set_accelerator, get_accelerator, set_accelerator
 from .base_accelerator import BaseAccelerator
+from .cpu_accelerator import CpuAccelerator
 from .cuda_accelerator import CudaAccelerator
 from .npu_accelerator import NpuAccelerator
...
@@ -10,4 +11,5 @@ __all__ = [
     "BaseAccelerator",
     "CudaAccelerator",
     "NpuAccelerator",
+    "CpuAccelerator",
 ]
...
@@ -3,6 +3,7 @@ from collections import OrderedDict
 from typing import Union

 from .base_accelerator import BaseAccelerator
+from .cpu_accelerator import CpuAccelerator
 from .cuda_accelerator import CudaAccelerator
 from .npu_accelerator import NpuAccelerator
@@ -15,7 +16,7 @@ _ACCELERATOR = None
 # we use an ordered dictionary here to associate the
 # order with device check priority,
 # i.e. auto_set_accelerator will check cuda first
-_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator)
+_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)


 def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
@@ -43,19 +44,17 @@ def auto_set_accelerator() -> None:
     """
     global _ACCELERATOR

-    for _, accelerator_cls in _ACCELERATOR_MAPPING.items():
+    for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
         try:
             accelerator = accelerator_cls()
-            if accelerator.is_available():
+            if accelerator_name == "cpu" or accelerator.is_available():
                 _ACCELERATOR = accelerator
                 break
         except:
             pass

     if _ACCELERATOR is None:
-        raise RuntimeError(
-            f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}"
-        )
+        raise RuntimeError("No accelerator is available.")


 def get_accelerator() -> BaseAccelerator:
...
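Note: a short usage sketch of the selection API defined in api.py above. The OrderedDict encodes probe priority (cuda, then npu, then cpu as the unconditional fallback), and `set_accelerator` accepts either a mapping key or an accelerator instance:

```python
import torch

from colossalai.accelerator import auto_set_accelerator, get_accelerator, set_accelerator

auto_set_accelerator()  # probe cuda -> npu -> cpu
acc = get_accelerator()
x = torch.empty(4, device=acc.get_current_device())

set_accelerator("cpu")  # or pin a backend explicitly
assert get_accelerator().get_current_device() == torch.device("cpu")
```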
 #!/usr/bin/env python

 from abc import ABC, abstractmethod
-from typing import Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
...
@@ -8,6 +9,8 @@ __all__ = ["BaseAccelerator"]

 class BaseAccelerator(ABC):
+    support_set_device: bool = True
+
     def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
         self._name = name
         self._communication_backend = communication_backend
@@ -45,6 +48,12 @@ class BaseAccelerator(ABC):
     # =======================
     # device APIs
     # =======================
+    @abstractmethod
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+
     @abstractmethod
     def current_device(self) -> int:
         """
@@ -52,7 +61,7 @@ class BaseAccelerator(ABC):
         """

     @abstractmethod
-    def set_device(self, device: Union[torch.device, int]) -> None:
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
         """
         Bind the current process to a device.
         """
@@ -79,3 +88,226 @@ class BaseAccelerator(ABC):
         """
         Return the number of devices on the machine.
         """
def set_to_device(self, models: Any) -> Any:
"""
Send model to device.

:param models: an nn.Module or a list of modules
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(self.get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(self.get_current_device())
else:
return models.to(self.get_current_device())
@abstractmethod
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the capability of a device.
"""
@abstractmethod
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
@abstractmethod
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
@abstractmethod
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc.
"""
# =======================
# random number generator APIs
# =======================
@abstractmethod
def get_rng_state(self, device="cuda") -> torch.Tensor:
"""
Returns the random number generator state of the specified device as a ByteTensor.
"""
@abstractmethod
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
@abstractmethod
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
"""
Sets the random number generator state of the specified device.
"""
@abstractmethod
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
@abstractmethod
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current device.
"""
@abstractmethod
def manual_seed_all(self, seed: int) -> None:
"""
Sets the seed for generating random numbers on all devices.
"""
@abstractmethod
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current device.
"""
@abstractmethod
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all devices.
"""
@abstractmethod
def initial_seed(self) -> int:
"""
Returns the current random seed of the current device.
"""
# =======================
# memory management APIs
# =======================
@abstractmethod
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi.
"""
@abstractmethod
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of CUDA memory allocator statistics for a given device.
"""
@abstractmethod
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
@abstractmethod
def memory_snapshot(self):
"""
Returns a snapshot of the CUDA memory allocator state across all devices.
"""
@abstractmethod
def memory_allocated(self, device=None) -> int:
"""
Returns the current device memory occupied by tensors in bytes for a given device.
"""
@abstractmethod
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum device memory occupied by tensors in bytes for a given device.
"""
@abstractmethod
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum device memory occupied by tensors for a given device.
"""
@abstractmethod
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device.
"""
@abstractmethod
def memory_reserved(self, device=None) -> int:
"""
Returns the current device memory managed by the caching allocator in bytes for a given device.
"""
@abstractmethod
def max_memory_reserved(self, device=None) -> int:
"""
Returns the maximum device memory managed by the caching allocator in bytes for a given device.
"""
@abstractmethod
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
@abstractmethod
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the device memory allocator.
"""
# =======================
# streams and events APIs
# =======================
@abstractmethod
def Stream(self, device=None, priority=0, **kwargs):
"""
A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
"""
@abstractmethod
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
"""
@abstractmethod
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
@abstractmethod
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
@abstractmethod
def set_stream(self, stream_):
"""
Sets the current stream. This is a wrapper API to set the stream.
"""
@abstractmethod
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
# =======================
# amp APIs
# =======================
@abstractmethod
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""
#!/usr/bin/env python
import resource
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import psutil
import torch
from .base_accelerator import BaseAccelerator
__all__ = ["CpuAccelerator"]
class CpuAccelerator(BaseAccelerator):
"""
Accelerator class for cpu.
"""

support_set_device: bool = False
def __init__(self):
super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False)
# =======================
# device APIs
# =======================
def get_current_device(self) -> torch.device:
"""
Return the current device.
"""
return torch.device("cpu")
def current_device(self) -> int:
"""
Return the current device index.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
"""
Bind the current process to a device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def synchronize(self, device: Union[torch.device, int] = None):
"""
Synchronize the current process.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def is_available(self):
"""
Check if the accelerator is available.
"""
return True
def device_count(self):
"""
Return the number of devices on the machine.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the cuda capability of a device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
"""
raise RuntimeError("this method is not supported for cpu accelerator")
# =======================
# random number generator APIs
# =======================
def get_rng_state(self, device=None) -> torch.Tensor:
"""
Returns the random number generator state of the CPU as a ByteTensor.
"""
# torch.get_rng_state() takes no device argument; the parameter exists only to match the accelerator interface
return torch.get_rng_state()
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None:
"""
Sets the random number generator state of the CPU.
"""
torch.set_rng_state(new_state)
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current GPU.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def manual_seed_all(self, seed: int) -> None:
"""
Set the random seed for all processes.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current GPU.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all GPUs.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def initial_seed(self) -> int:
"""
Returns the current random seed of the current GPU.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
# =======================
# memory management APIs
# =======================
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of CUDA memory allocator statistics for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_snapshot(self):
"""
Returns a snapshot of the CUDA memory allocator state across all devices.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_allocated(self, device=None) -> int:
"""
Returns the current GPU memory occupied by tensors in bytes for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def memory_reserved(self, device=None) -> int:
"""
Returns the resident set size (RSS) of the current process in bytes, as the CPU analogue of memory managed by a caching allocator.
"""
return psutil.Process().memory_info().rss
def max_memory_reserved(self, device=None) -> int:
"""
Returns the peak resident set size of the current process as reported by getrusage (note: ru_maxrss is in kilobytes on Linux).
"""
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
max_memory = int(psutil.virtual_memory().total * fraction)
_, hard = resource.getrlimit(resource.RLIMIT_AS)
resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard))
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the CUDA memory allocator.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
# =======================
# streams and events APIs
# =======================
def Stream(self, device=None, priority=0, **kwargs):
"""
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def set_stream(self, stream_):
"""
Sets the current stream. This is a wrapper API to set the stream.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
raise RuntimeError("this method is not supported for cpu accelerator")
# =======================
# amp APIs
# =======================
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""
return nullcontext
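Note: on CPU the memory-fraction API is implemented with `resource.setrlimit` rather than a device allocator, and `memory_reserved` falls back to the process RSS. A quick, Linux-oriented sketch of that behavior (RLIMIT_AS is not honored the same way on macOS):

```python
import resource

import psutil

from colossalai.accelerator import CpuAccelerator

acc = CpuAccelerator()
acc.set_per_process_memory_fraction(0.5)  # caps the address space at ~50% of RAM

soft, _ = resource.getrlimit(resource.RLIMIT_AS)
print(soft, int(psutil.virtual_memory().total * 0.5))  # soft limit == the cap
print(acc.memory_reserved())  # RSS of this process, in bytes
```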
 #!/usr/bin/env python

-from typing import Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+import torch.distributed as dist

 from .base_accelerator import BaseAccelerator
...
@@ -19,16 +21,26 @@ class CudaAccelerator(BaseAccelerator):
     # =======================
     # device APIs
     # =======================
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+        return torch.device(f"cuda:{torch.cuda.current_device()}")
+
     def current_device(self) -> int:
         """
         Return the current device index.
         """
         return torch.cuda.current_device()

-    def set_device(self, device: Union[torch.device, int]) -> None:
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
         """
         Bind the current process to a device.
         """
+        if device is None:
+            if not dist.is_initialized():
+                raise RuntimeError("Cannot get current device when distributed is not initialized.")
+            device = dist.get_rank() % self.device_count()
         torch.cuda.set_device(device)

     def get_device_name(self, device: Union[torch.device, int]) -> str:
@@ -54,3 +66,211 @@ class CudaAccelerator(BaseAccelerator):
         Return the number of devices on the machine.
         """
         return torch.cuda.device_count()
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the cuda capability of a device.
"""
return torch.cuda.get_device_capability(device)
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
return torch.cuda.get_device_name(device)
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
return torch.cuda.get_device_properties(device)
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
"""
return torch.cuda.utilization(device)
# =======================
# random number generator APIs
# =======================
def get_rng_state(self, device="cuda") -> torch.Tensor:
"""
Returns the random number generator state of the specified GPU as a ByteTensor.
"""
return torch.cuda.get_rng_state(device)
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
return torch.cuda.get_rng_state_all()
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
"""
Sets the random number generator state of the specified GPU.
"""
torch.cuda.set_rng_state(new_state, device)
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
torch.cuda.set_rng_state_all(new_states)
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current GPU.
"""
torch.cuda.manual_seed(seed)
def manual_seed_all(self, seed: int) -> None:
"""
Set the random seed for all processes.
"""
torch.cuda.manual_seed_all(seed)
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current GPU.
"""
torch.cuda.seed()
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all GPUs.
"""
torch.cuda.seed_all()
def initial_seed(self) -> int:
"""
Returns the current random seed of the current GPU.
"""
return torch.cuda.initial_seed()
# =======================
# memory management APIs
# =======================
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
"""
torch.cuda.empty_cache()
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of CUDA memory allocator statistics for a given device.
"""
return torch.cuda.memory_stats(device=device)
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
return torch.cuda.memory_summary(device=device, abbreviated=abbreviated)
def memory_snapshot(self):
"""
Returns a snapshot of the CUDA memory allocator state across all devices.
"""
return torch.cuda.memory_snapshot()
def memory_allocated(self, device=None) -> int:
"""
Returns the current GPU memory occupied by tensors in bytes for a given device.
"""
return torch.cuda.memory_allocated(device=device)
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
"""
return torch.cuda.max_memory_allocated(device=device)
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
"""
torch.cuda.reset_max_memory_allocated(device=device)
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
"""
torch.cuda.reset_max_memory_cached(device=device)
def memory_reserved(self, device=None) -> int:
"""
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.cuda.memory_reserved(device=device)
def max_memory_reserved(self, device=None) -> int:
"""
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.cuda.max_memory_reserved(device=device)
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
torch.cuda.set_per_process_memory_fraction(fraction, device=device)
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the CUDA memory allocator.
"""
torch.cuda.reset_peak_memory_stats(device=device)
# =======================
# streams and events APIs
# =======================
def Stream(self, device=None, priority=0, **kwargs):
"""
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
"""
return torch.cuda.Stream(device, priority, **kwargs)
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
"""
return torch.cuda.Event(enable_timing, blocking, interprocess)
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
return torch.cuda.current_stream(device)
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
return torch.cuda.default_stream(device)
def set_stream(self, stream_):
"""
Sets the current stream. This is a wrapper API to set the stream.
"""
torch.cuda.set_stream(stream_)
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
return torch.cuda.stream(stream_)
# =======================
# amp APIs
# =======================
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
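Note: with `autocast` routed through the accelerator, mixed-precision call sites no longer hard-code `torch.cuda.amp` (see the `TorchAMPModule` change further down). A device-agnostic sketch for the CUDA path:

```python
import torch

from colossalai.accelerator import get_accelerator

acc = get_accelerator()
device = acc.get_current_device()
model = torch.nn.Linear(16, 16).to(device)
x = torch.randn(8, 16, device=device)

# dispatches to torch.cuda.amp.autocast here, torch.npu.amp.autocast on NPU
with acc.autocast(dtype=torch.float16):
    y = model(x)
```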
 #!/usr/bin/env python

-from typing import Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+import torch.distributed as dist

 from .base_accelerator import BaseAccelerator

+IS_NPU_AVAILABLE = False
 try:
     import torch_npu  # noqa
+
+    IS_NPU_AVAILABLE = True
 except ImportError:
     pass
...
@@ -26,16 +30,26 @@ class NpuAccelerator(BaseAccelerator):
     # =======================
     # device APIs
     # =======================
+    def get_current_device(self) -> torch.device:
+        """
+        Return the current device.
+        """
+        return torch.device(f"npu:{torch.npu.current_device()}")
+
     def current_device(self) -> int:
         """
         Return the current device index.
         """
         return torch.npu.current_device()

-    def set_device(self, device: Union[torch.device, int]) -> None:
+    def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
         """
         Bind the current process to a device.
         """
+        if device is None:
+            if not dist.is_initialized():
+                raise RuntimeError("Cannot get current device when distributed is not initialized.")
+            device = dist.get_rank() % self.device_count()
         torch.npu.set_device(device)

     def get_device_name(self, device: Union[torch.device, int]) -> str:
@@ -61,3 +75,211 @@ class NpuAccelerator(BaseAccelerator):
         Return the number of devices on the machine.
         """
         return torch.npu.device_count()
def get_device_capability(self, device=None) -> Tuple[int, int]:
"""
Gets the npu capability of a device.
"""
return torch.npu.get_device_capability(device)
def get_device_name(self, device=None) -> str:
"""
Gets the name of a device.
"""
return torch.npu.get_device_name(device)
def get_device_properties(self, device):
"""
Gets the properties of a device.
"""
return torch.npu.get_device_properties(device)
def utilization(self, device=None) -> int:
"""
Returns the percent of time over the past sample period during which one or more kernels was executing on the NPU, as given by npu-smi.
"""
return torch.npu.utilization(device)
# =======================
# random number generator APIs
# =======================
def get_rng_state(self, device="npu") -> torch.Tensor:
"""
Returns the random number generator state of the specified GPU as a ByteTensor.
"""
return torch.npu.get_rng_state(device)
def get_rng_state_all(self) -> List[torch.Tensor]:
"""
Returns a list of ByteTensor representing the random number states of all devices.
"""
return torch.npu.get_rng_state_all()
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None:
"""
Sets the random number generator state of the specified GPU.
"""
torch.npu.set_rng_state(new_state, device)
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
"""
Sets the random number generator state of all devices.
"""
torch.npu.set_rng_state_all(new_states)
def manual_seed(self, seed: int) -> None:
"""
Sets the seed for generating random numbers for the current GPU.
"""
torch.npu.manual_seed(seed)
def manual_seed_all(self, seed: int) -> None:
"""
Set the random seed for all processes.
"""
torch.npu.manual_seed_all(seed)
def seed(self) -> None:
"""
Sets the seed for generating random numbers to a random number for the current GPU.
"""
torch.npu.seed()
def seed_all(self) -> None:
"""
Sets the seed for generating random numbers to a random number on all GPUs.
"""
torch.npu.seed_all()
def initial_seed(self) -> int:
"""
Returns the current random seed of the current GPU.
"""
return torch.npu.initial_seed()
# =======================
# memory management APIs
# =======================
def empty_cache(self) -> None:
"""
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
"""
torch.npu.empty_cache()
def memory_stats(self, device=None) -> Dict[str, Any]:
"""
Returns a dictionary of npu memory allocator statistics for a given device.
"""
return torch.npu.memory_stats(device=device)
def memory_summary(self, device=None, abbreviated=False) -> str:
"""
Returns a human-readable printout of the current memory allocator statistics for a given device.
"""
return torch.npu.memory_summary(device=device, abbreviated=abbreviated)
def memory_snapshot(self):
"""
Returns a snapshot of the npu memory allocator state across all devices.
"""
return torch.npu.memory_snapshot()
def memory_allocated(self, device=None) -> int:
"""
Returns the current GPU memory occupied by tensors in bytes for a given device.
"""
return torch.npu.memory_allocated(device=device)
def max_memory_allocated(self, device=None) -> int:
"""
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
"""
return torch.npu.max_memory_allocated(device=device)
def reset_max_memory_allocated(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
"""
torch.npu.reset_max_memory_allocated(device=device)
def reset_max_memory_cached(self, device=None) -> None:
"""
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
"""
torch.npu.reset_max_memory_cached(device=device)
def memory_reserved(self, device=None) -> int:
"""
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.npu.memory_reserved(device=device)
def max_memory_reserved(self, device=None) -> int:
"""
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
"""
return torch.npu.max_memory_reserved(device=device)
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
"""
Set memory fraction for a process.
"""
torch.npu.set_per_process_memory_fraction(fraction, device=device)
def reset_peak_memory_stats(self, device=None) -> None:
"""
Resets the "peak" stats tracked by the npu memory allocator.
"""
torch.npu.reset_peak_memory_stats(device=device)
# =======================
# streams and events APIs
# =======================
def Stream(self, device=None, priority=0, **kwargs):
"""
A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details.
"""
return torch.npu.Stream(device, priority, **kwargs)
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
"""
npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams.
"""
return torch.npu.Event(enable_timing, blocking, interprocess)
def current_stream(self, device=None):
"""
Returns the currently selected Stream for a given device.
"""
return torch.npu.current_stream(device)
def default_stream(self, device=None):
"""
Returns the default Stream for a given device.
"""
return torch.npu.default_stream(device)
def set_stream(self, stream_):
"""
Sets the current stream. This is a wrapper API to set the stream.
"""
torch.npu.set_stream(stream_)
def stream(self, stream_):
"""
Wrapper around the Context-manager StreamContext that selects a given stream.
"""
return torch.npu.stream(stream_)
# =======================
# amp APIs
# =======================
def autocast(
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
) -> Callable:
"""
Return autocast function
"""
return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
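Note: the new optional `device` argument on `set_device` lets launch scripts bind each rank without computing a local index themselves; passing `None` derives it as `rank % device_count()`. A hedged sketch of the intended call pattern (assumes a torchrun-style launch; the backend string is cluster-specific, "hccl" being the usual choice for NPU):

```python
import torch.distributed as dist

from colossalai.accelerator import get_accelerator

dist.init_process_group(backend="nccl")  # e.g. "hccl" on NPU clusters
get_accelerator().set_device(None)  # binds this process to rank % device_count()
print(get_accelerator().get_current_device())
```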
...
@@ -7,8 +7,8 @@ from typing import Dict
 import torch
 from torch import Tensor

+from colossalai.accelerator import get_accelerator
 from colossalai.logging import get_dist_logger
-from colossalai.utils.device import get_current_device

 __all__ = ["BaseGradScaler"]
@@ -23,7 +23,7 @@ class BaseGradScaler(ABC):
     def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
-        self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
+        self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
         self._verbose = verbose

         if self._verbose:
...
...
@@ -5,7 +5,7 @@ from typing import Optional

 import torch

-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator

 from .base_grad_scaler import BaseGradScaler
@@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler):
         hysteresis: int = 2,
         verbose: bool = False,
     ):
         super().__init__(initial_scale, verbose)
         if min_scale:
-            self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
+            self._min_scale = torch.tensor(
+                [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
         else:
             self._min_scale = None

         if max_scale:
-            self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
+            self._max_scale = torch.tensor(
+                [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+            )
         else:
             self._max_scale = None
@@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler):
         return state_dict

     def load_state_dict(self, state_dict):
-        self._scale = state_dict["scale"].to(get_current_device())
+        self._scale = state_dict["scale"].to(get_accelerator().get_current_device())
         self._growth_factor = state_dict["growth_factor"]
         self._backoff_factor = state_dict["backoff_factor"]
         self._hysteresis = state_dict["hysteresis"]
...
...
@@ -5,8 +5,8 @@ import torch
 import torch.distributed as dist
 from torch import Tensor

+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.utils import get_current_device

 from .base import MixedPrecisionMixin
@@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
             max_scale=max_scale,
         )
         self.optim_state = OptimState.UNSCALED
-        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
+        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())

     @property
     def loss_scale(self) -> float:
...
...
@@ -4,10 +4,10 @@ from typing import Dict, Tuple
 import torch
 from torch.optim import Optimizer

+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device

 from .base_offload_module import BaseOffloadModule
 from .region import Region
@@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper):
             hysteresis=hysteresis,
             max_scale=max_scale,
         )
-        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+        self._found_overflow: torch.Tensor = torch.zeros(
+            1, dtype=torch.int64, device=get_accelerator().get_current_device()
+        )
         self._logger = get_dist_logger()

     def _set_grad_ptr(self):
...
...
@@ -11,7 +11,7 @@ except:
 import torch
 from torch.fx.node import Node

-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator

 from .region import Region
 from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
@@ -57,7 +57,10 @@ class Solver(ABC):
         if memory_budget > 0:
             self.memory_budget = memory_budget * self.error_factor
         else:
-            self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
+            self.memory_budget = (
+                torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+                * self.error_factor
+            )

         self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
         self.comp_power: float = self._extract_computing_power()
...
...
@@ -5,8 +5,8 @@ import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer

+from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.utils.device import autocast

 from .mixed_precision_base import MixedPrecision
@@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
         super().__init__(module)

     def forward(self, *args, **kwargs):
-        with autocast():
+        with get_accelerator().autocast():
             return self.module(*args, **kwargs)
...
...
@@ -12,6 +12,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
@@ -24,8 +25,6 @@ from colossalai.checkpoint_io.utils import (
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -367,7 +366,7 @@ class GeminiPlugin(DPPluginBase):
             assert placement_policy == "static", "NPU only supports static placement policy"

         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
-            chunk_init_device=(chunk_init_device or get_current_device()),
+            chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
             placement_policy=placement_policy,
             enable_gradient_accumulation=enable_gradient_accumulation,
             shard_param_frac=shard_param_frac,
...
...
@@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

+from colossalai.accelerator import get_accelerator
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
@@ -29,7 +30,6 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
-from colossalai.utils.device import get_current_device

 from .pp_plugin_base import PipelinePluginBase
@@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
             self.mixed_precision = torch.bfloat16
         if self.mixed_precision is not None:
             module = module.to(self.mixed_precision)
-        module = module.to(get_current_device())
+        module = module.to(get_accelerator().get_current_device())

         # setting input type cast when using mixed precision
         self.convert_fn = None
@@ -346,7 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
             if self.pp_size > 1:
@@ -385,7 +387,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 total_norm_exponentiated += grad_norm_exponentiated

-            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -543,7 +547,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # so we need to calculate the norm of 'tp' and 'pp' gradients.
             total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)

-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -586,7 +592,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated

-            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -798,7 +806,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # so we only need to calculate the norm of 'tp' and 'pp' gradients.
             total_norm = super()._compute_grad_norm(gradients, norm_type)

-            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -837,7 +847,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated

-            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
             if dp_size > 1:
                 # compute norm in dp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
...
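Note: the grad-norm changes above repeat one pattern: wrap the local norm in a one-element float32 tensor on the accelerator's current device, then all-reduce it across the relevant parallel groups. A sketch of that pattern as a standalone helper (the function name and `groups` argument are illustrative, not part of the PR):

```python
import torch
import torch.distributed as dist

from colossalai.accelerator import get_accelerator


def reduce_norm(local_norm: float, groups, use_max: bool = False) -> float:
    """All-reduce a scalar norm across the given process groups."""
    t = torch.tensor([local_norm], device=get_accelerator().get_current_device(), dtype=torch.float32)
    op = dist.ReduceOp.MAX if use_max else dist.ReduceOp.SUM
    for group in groups:
        dist.all_reduce(t, op=op, group=group)
    return t.item()
```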