Unverified Commit e5ce4c8e authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
parent 8d56c9c3
...@@ -10,7 +10,7 @@ from torch.utils._pytree import tree_map ...@@ -10,7 +10,7 @@ from torch.utils._pytree import tree_map
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule from .base import PipelineSchedule
......
...@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map ...@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule from .base import PipelineSchedule
......
...@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map ...@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from ._utils import ( from ._utils import (
detach, detach,
......
...@@ -2,16 +2,19 @@ ...@@ -2,16 +2,19 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import torch.nn as nn import torch.nn as nn
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from ._operation import hook_paramter_in_backward
from ._operation import hook_paramter_in_backward
from .utils import SeqParallelUtils from .utils import SeqParallelUtils
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"] __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
try: try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm from apex.contrib.layer_norm.layer_norm import FastLayerNorm
EnableFastLayerNorm = True EnableFastLayerNorm = True
except ImportError: except ImportError:
EnableFastLayerNorm = False EnableFastLayerNorm = False
...@@ -19,10 +22,27 @@ except ImportError: ...@@ -19,10 +22,27 @@ except ImportError:
try: try:
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
class FusedLayerNormWithHook(ApexFusedLayerNorm):
    """Apex fused LayerNorm whose output is routed through ``hook_paramter_in_backward``."""

    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
        super().__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input):
        """Run the parent forward pass, then hand the result plus weight and bias to the hook helper."""
        normed = super().forward(input)
        return hook_paramter_in_backward(normed, self.weight, self.bias)
class FusedRMSNormWithHook(ApexFusedRMSNorm):
    """Apex fused RMSNorm whose output is routed through ``hook_paramter_in_backward``."""

    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
        super().__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input):
        """Run the parent forward pass, then hand the result plus the weight to the hook helper."""
        normed = super().forward(input)
        # Only the weight is passed to the hook here; no bias is forwarded.
        return hook_paramter_in_backward(normed, self.weight)
except ImportError: except ImportError:
warnings.warn( warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
)
FAST_LAYERNORM_SUPPORTED_SIZE = [ FAST_LAYERNORM_SUPPORTED_SIZE = [
1024, 1024,
...@@ -52,6 +72,7 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [ ...@@ -52,6 +72,7 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
] ]
if EnableFastLayerNorm: if EnableFastLayerNorm:
class FastLayerNormWithHook(FastLayerNorm): class FastLayerNormWithHook(FastLayerNorm):
def __init__(self, hidden_size, eps=0.00001): def __init__(self, hidden_size, eps=0.00001):
super().__init__(hidden_size, eps) super().__init__(hidden_size, eps)
...@@ -60,25 +81,7 @@ if EnableFastLayerNorm: ...@@ -60,25 +81,7 @@ if EnableFastLayerNorm:
output = super().forward(input) output = super().forward(input)
output = hook_paramter_in_backward(output, self.weight, self.bias) output = hook_paramter_in_backward(output, self.weight, self.bias)
return output return output
class FusedLayerNormWithHook(ApexFusedLayerNorm):
    """Apex fused LayerNorm whose output is routed through ``hook_paramter_in_backward``."""

    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
        super().__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input):
        """Run the parent forward pass, then hand the result plus weight and bias to the hook helper."""
        normed = super().forward(input)
        return hook_paramter_in_backward(normed, self.weight, self.bias)
class FusedRMSNormWithHook(ApexFusedRMSNorm):
    """Apex fused RMSNorm whose output is routed through ``hook_paramter_in_backward``."""

    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
        super().__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input):
        """Run the parent forward pass, then hand the result plus the weight to the hook helper."""
        normed = super().forward(input)
        # Only the weight is passed to the hook here; no bias is forwarded.
        return hook_paramter_in_backward(normed, self.weight)
class BaseLayerNorm(ABC): class BaseLayerNorm(ABC):
@abstractmethod @abstractmethod
...@@ -244,12 +247,13 @@ class FusedRMSNorm(BaseLayerNorm): ...@@ -244,12 +247,13 @@ class FusedRMSNorm(BaseLayerNorm):
""" """
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
""" """
def __init__(self) -> None: def __init__(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"FusedRMSNorm is not implemented as a physical class. " "FusedRMSNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex." "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
) )
@staticmethod @staticmethod
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r""" r"""
...@@ -264,7 +268,7 @@ class FusedRMSNorm(BaseLayerNorm): ...@@ -264,7 +268,7 @@ class FusedRMSNorm(BaseLayerNorm):
nn.Module: FusedRMSNorm module. nn.Module: FusedRMSNorm module.
""" """
try: try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm pass
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
...@@ -282,7 +286,9 @@ class FusedRMSNorm(BaseLayerNorm): ...@@ -282,7 +286,9 @@ class FusedRMSNorm(BaseLayerNorm):
eps = module.eps eps = module.eps
elementwise_affine = module.elementwise_affine elementwise_affine = module.elementwise_affine
rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
)
rmsnorm.weight = module.weight rmsnorm.weight = module.weight
......
...@@ -7,7 +7,7 @@ from .common import ( ...@@ -7,7 +7,7 @@ from .common import (
is_ddp_ignored, is_ddp_ignored,
set_seed, set_seed,
) )
from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
from .tensor_detector import TensorDetector from .tensor_detector import TensorDetector
from .timer import MultiTimer, Timer from .timer import MultiTimer, Timer
...@@ -29,4 +29,5 @@ __all__ = [ ...@@ -29,4 +29,5 @@ __all__ = [
"set_seed", "set_seed",
"is_ddp_ignored", "is_ddp_ignored",
"set_device", "set_device",
"IS_NPU_AVAILABLE",
] ]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
import torch
import torch.distributed as dist
def set_to_cuda(models):
    """Move a module, or every module in a list, to the current device.

    :param models: nn.module or a list of module
    :return: a list of moved modules when ``models`` is a list with more than
        one element; otherwise the single moved module
    """
    if not isinstance(models, list):
        return models.to(get_current_device())
    if len(models) > 1:
        return [module.to(get_current_device()) for module in models]
    # A one-element list is unwrapped; an empty list raises IndexError here.
    return models[0].to(get_current_device())
def get_current_device() -> torch.device:
    """Return the device currently in use.

    Yields ``cuda:<index of the current CUDA device>`` when CUDA is
    available, and the CPU device otherwise.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{torch.cuda.current_device()}")
def synchronize():
    """Call ``torch.cuda.synchronize()`` when CUDA is available.

    Does nothing on hosts without CUDA.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
def empty_cache():
    """Call ``torch.cuda.empty_cache()`` when CUDA is available.

    Does nothing on hosts without CUDA.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def set_device(index: Optional[int] = None) -> None:
    """Select the CUDA device for this process.

    :param index: device index to select; when ``None`` it is derived as
        ``dist.get_rank() % torch.cuda.device_count()``
    """
    target = dist.get_rank() % torch.cuda.device_count() if index is None else index
    torch.cuda.set_device(target)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
# Module-level flag consulted by the helpers below to decide whether calls
# should be routed to ``torch.npu``. Stays False when ``torch_npu`` cannot be
# imported.
IS_NPU_AVAILABLE: bool = False
try:
    # Imported only for its side effect (hence the noqa).
    # NOTE(review): presumably this import is what makes ``torch.npu`` exist — confirm.
    import torch_npu  # noqa

    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    # torch_npu is an optional dependency; keep the default of False.
    pass
def set_to_cuda(models):
    """Move a module, or every module in a list, to the current device.

    Despite the name, the target is whatever ``get_current_device()``
    returns, which is not necessarily a CUDA device.

    :param models: nn.module or a list of module
    :return: a list of moved modules when ``models`` is a list with more than
        one element; otherwise the single moved module
    """
    if not isinstance(models, list):
        return models.to(get_current_device())
    if len(models) > 1:
        return [module.to(get_current_device()) for module in models]
    # A one-element list is unwrapped; an empty list raises IndexError here.
    return models[0].to(get_current_device())
def get_current_device() -> torch.device:
    """Return the device currently in use (cuda/npu/cpu).

    CUDA is checked first, then NPU; the CPU device is the fallback when
    neither is available.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    if IS_NPU_AVAILABLE:
        return torch.device(f"npu:{torch.npu.current_device()}")
    return torch.device("cpu")
def _dispatch_device_func(fn_name: str, *args, **kwargs):
    """Call ``fn_name`` on the active accelerator namespace.

    :param fn_name: attribute name looked up on ``torch.cuda`` (preferred)
        or ``torch.npu``
    :param args: positional arguments forwarded unchanged
    :param kwargs: keyword arguments forwarded unchanged
    :return: whatever the backend callable returns
    :raises RuntimeError: when neither CUDA nor NPU is available
    """
    if torch.cuda.is_available():
        backend = torch.cuda
    elif IS_NPU_AVAILABLE:
        backend = torch.npu
    else:
        raise RuntimeError("No device available")
    return getattr(backend, fn_name)(*args, **kwargs)
# device semantics
def can_device_access_peer(device, peer_device) -> bool:
    """Forward ``can_device_access_peer(device, peer_device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    peer_access = _dispatch_device_func("can_device_access_peer", device, peer_device)
    return peer_access
def current_device() -> int:
    """Forward ``current_device()`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    device_index = _dispatch_device_func("current_device")
    return device_index
def current_stream(device=None):
    """Forward ``current_stream(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    active_stream = _dispatch_device_func("current_stream", device)
    return active_stream
def default_stream(device=None):
    """Forward ``default_stream(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    base_stream = _dispatch_device_func("default_stream", device)
    return base_stream
def device_count() -> int:
    """Forward ``device_count()`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    num_devices = _dispatch_device_func("device_count")
    return num_devices
def get_device_capability(device=None) -> Tuple[int, int]:
    """Forward ``get_device_capability(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    capability = _dispatch_device_func("get_device_capability", device)
    return capability
def get_device_name(device=None) -> str:
    """Forward ``get_device_name(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    device_name = _dispatch_device_func("get_device_name", device)
    return device_name
def get_device_properties(device):
    """Forward ``get_device_properties(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    properties = _dispatch_device_func("get_device_properties", device)
    return properties
def set_device(index: Optional[int] = None) -> None:
    """Select the accelerator device for this process on the active backend.

    :param index: device index to select; when ``None`` it is derived as
        ``dist.get_rank() % device_count()``
    """
    target = dist.get_rank() % device_count() if index is None else index
    _dispatch_device_func("set_device", target)
def set_stream(stream_):
    """Forward ``set_stream(stream_)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("set_stream", stream_)
    return outcome
def stream(stream_):
    """Forward ``stream(stream_)`` to the active backend (``torch.cuda`` or ``torch.npu``) and return its result."""
    stream_scope = _dispatch_device_func("stream", stream_)
    return stream_scope
def synchronize():
    """Synchronize the active accelerator, if there is one.

    Forwards to ``synchronize()`` on ``torch.cuda`` or ``torch.npu``. When
    neither backend is available this is a no-op rather than an error, so
    callers that synchronize unconditionally (the timer module imports this
    helper) keep working on CPU-only hosts.

    :return: the backend's return value, or ``None`` when no accelerator exists
    """
    if not (torch.cuda.is_available() or IS_NPU_AVAILABLE):
        # Nothing to wait for on a CPU-only host.
        return None
    return _dispatch_device_func("synchronize")
def utilization(device=None) -> int:
    """Forward ``utilization(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    usage = _dispatch_device_func("utilization", device)
    return usage
# random number generator
def get_rng_state(device="cuda") -> torch.Tensor:
    """Forward ``get_rng_state(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    # NOTE(review): the default "cuda" is forwarded unchanged even when the
    # call is routed to torch.npu — confirm the NPU backend accepts it.
    rng_state = _dispatch_device_func("get_rng_state", device)
    return rng_state
def get_rng_state_all() -> List[torch.Tensor]:
    """Forward ``get_rng_state_all()`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    rng_states = _dispatch_device_func("get_rng_state_all")
    return rng_states
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
    """Forward ``set_rng_state(new_state, device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    # NOTE(review): the default "cuda" is forwarded unchanged even when the
    # call is routed to torch.npu — confirm the NPU backend accepts it.
    outcome = _dispatch_device_func("set_rng_state", new_state, device)
    return outcome
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
    """Forward ``set_rng_state_all(new_states)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("set_rng_state_all", new_states)
    return outcome
def manual_seed(seed: int) -> None:
    """Forward ``manual_seed(seed)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("manual_seed", seed)
    return outcome
def manual_seed_all(seed: int) -> None:
    """Forward ``manual_seed_all(seed)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("manual_seed_all", seed)
    return outcome
def seed() -> None:
    """Forward ``seed()`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("seed")
    return outcome
def seed_all() -> None:
    """Forward ``seed_all()`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("seed_all")
    return outcome
def initial_seed() -> int:
    """Forward ``initial_seed()`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    seed_value = _dispatch_device_func("initial_seed")
    return seed_value
# streams and events
def Stream(device=None, priority=0, **kwargs):
    """Construct a backend ``Stream`` (``torch.cuda.Stream`` or ``torch.npu.Stream``).

    ``device`` and ``priority`` are passed positionally; extra keyword
    arguments are forwarded unchanged.
    """
    new_stream = _dispatch_device_func("Stream", device, priority, **kwargs)
    return new_stream
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
    """Construct a backend ``Event`` (``torch.cuda.Event`` or ``torch.npu.Event``).

    The three flags are passed positionally in the order
    ``enable_timing, blocking, interprocess``.
    """
    new_event = _dispatch_device_func("Event", enable_timing, blocking, interprocess)
    return new_event
# memory management
def empty_cache() -> None:
    """Release cached accelerator memory, if there is an accelerator.

    Forwards to ``empty_cache()`` on ``torch.cuda`` or ``torch.npu``. When
    neither backend is available this is a no-op rather than an error, so
    cleanup code that calls it unconditionally keeps working on CPU-only
    hosts.
    """
    if not (torch.cuda.is_available() or IS_NPU_AVAILABLE):
        # No accelerator cache exists on a CPU-only host.
        return None
    return _dispatch_device_func("empty_cache")
def memory_stats(device=None) -> Dict[str, Any]:
    """Forward ``memory_stats(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    stats = _dispatch_device_func("memory_stats", device)
    return stats
def memory_summary(device=None, abbreviated=False) -> str:
    """Forward ``memory_summary(device, abbreviated)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    summary = _dispatch_device_func("memory_summary", device, abbreviated)
    return summary
def memory_snapshot():
    """Forward ``memory_snapshot()`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    snapshot = _dispatch_device_func("memory_snapshot")
    return snapshot
def memory_allocated(device=None) -> int:
    """Forward ``memory_allocated(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    allocated = _dispatch_device_func("memory_allocated", device)
    return allocated
def max_memory_allocated(device=None) -> int:
    """Forward ``max_memory_allocated(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    peak_allocated = _dispatch_device_func("max_memory_allocated", device)
    return peak_allocated
def reset_max_memory_allocated(device=None) -> None:
    """Forward ``reset_max_memory_allocated(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("reset_max_memory_allocated", device)
    return outcome
def memory_reserved(device=None) -> int:
    """Forward ``memory_reserved(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    reserved = _dispatch_device_func("memory_reserved", device)
    return reserved
def max_memory_reserved(device=None) -> int:
    """Forward ``max_memory_reserved(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    peak_reserved = _dispatch_device_func("max_memory_reserved", device)
    return peak_reserved
def set_per_process_memory_fraction(fraction: float, device=None) -> None:
    """Forward ``set_per_process_memory_fraction(fraction, device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
    return outcome
def reset_peak_memory_stats(device=None) -> None:
    """Forward ``reset_peak_memory_stats(device)`` to the active backend (``torch.cuda`` or ``torch.npu``)."""
    outcome = _dispatch_device_func("reset_peak_memory_stats", device)
    return outcome
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import time import time
from typing import Tuple from typing import Tuple
from .cuda import synchronize from .device import synchronize
class Timer: class Timer:
......
...@@ -7,6 +7,7 @@ import torch.distributed as dist ...@@ -7,6 +7,7 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.device import IS_NPU_AVAILABLE
class TensorState(Enum): class TensorState(Enum):
...@@ -172,7 +173,7 @@ class Chunk: ...@@ -172,7 +173,7 @@ class Chunk:
if self.chunk_temp is not None: if self.chunk_temp is not None:
# this chunk is not closed # this chunk is not closed
if self.chunk_temp.device.type == "cuda": if self.chunk_temp.device.type == "cuda" or self.chunk_temp.device.type == "npu":
cuda_memory += self.chunk_mem cuda_memory += self.chunk_mem
else: else:
cpu_memory += self.chunk_mem cpu_memory += self.chunk_mem
...@@ -191,10 +192,8 @@ class Chunk: ...@@ -191,10 +192,8 @@ class Chunk:
if self.chunk_temp is not None: if self.chunk_temp is not None:
return self.chunk_temp.device.type return self.chunk_temp.device.type
else: else:
if self.is_gathered: if self.is_gathered or self.cuda_shard is not None:
return "cuda" return "npu" if IS_NPU_AVAILABLE else "cuda"
elif self.cuda_shard is not None:
return "cuda"
else: else:
return "cpu" return "cpu"
...@@ -329,12 +328,12 @@ class Chunk: ...@@ -329,12 +328,12 @@ class Chunk:
# when the current chunk is not synchronized with the optimizer # when the current chunk is not synchronized with the optimizer
# just use another way for the movement # just use another way for the movement
if not self.optim_sync_flag: if not self.optim_sync_flag:
assert device.type == "cuda", "each chunk should first be moved to CUDA" assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
self.__paired_shard_move() self.__paired_shard_move()
self.optim_sync_flag = True self.optim_sync_flag = True
return return
if device.type == "cuda": if device.type == "cuda" or device.type == "npu":
assert device == get_current_device(), "can't move chunk to another device" assert device == get_current_device(), "can't move chunk to another device"
if self.cuda_shard: if self.cuda_shard:
...@@ -484,7 +483,7 @@ class Chunk: ...@@ -484,7 +483,7 @@ class Chunk:
assert friend_chunk.is_gathered is True assert friend_chunk.is_gathered is True
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
self.optim_sync_flag = True self.optim_sync_flag = True
elif friend_chunk.device_type == "cuda" and self.device_type == "cuda": elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"):
self.cuda_shard.copy_(friend_chunk.cuda_shard) self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True self.optim_sync_flag = True
self.cpu_vis_flag = False self.cpu_vis_flag = False
......
...@@ -206,7 +206,10 @@ class ChunkManager: ...@@ -206,7 +206,10 @@ class ChunkManager:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
""" """
assert tensor not in self.tensor_chunk_map assert tensor not in self.tensor_chunk_map
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() device_type = tensor.device.type
if device_type == "npu":
device_type = "cuda"
self.total_mem[device_type] += tensor.numel() * tensor.element_size()
def __repr__(self) -> str: def __repr__(self) -> str:
msg = [ msg = [
......
...@@ -10,32 +10,30 @@ import torch.nn as nn ...@@ -10,32 +10,30 @@ import torch.nn as nn
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.distributed_c10d import _get_default_group
from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import ModelWrapper from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
from colossalai.checkpoint_io.utils import gather_distributed_param
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
from .gemini_mgr import GeminiManager
from .memory_tracer import MemStats, OrderedParamGenerator
from .utils import get_temp_total_chunk_on_cuda
from colossalai.tensor.d_tensor import ( from colossalai.tensor.d_tensor import (
distribute_tensor, distribute_tensor,
distribute_tensor_with_customization, distribute_tensor_with_customization,
init_tensor_as_customization_distributed,
get_device_mesh, get_device_mesh,
get_global_shape,
get_sharding_spec, get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor, is_customized_distributed_tensor,
is_distributed_tensor, is_distributed_tensor,
get_global_shape,
init_as_dtensor
) )
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
from .gemini_mgr import GeminiManager
from .memory_tracer import MemStats, OrderedParamGenerator
from .utils import get_temp_total_chunk_on_cuda
try: try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
...@@ -162,7 +160,7 @@ class GeminiDDP(ModelWrapper): ...@@ -162,7 +160,7 @@ class GeminiDDP(ModelWrapper):
self._init_chunks( self._init_chunks(
param_order=param_order, param_order=param_order,
strict_ddp_mode=strict_ddp_mode, strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != "cuda", cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0),
pin_memory=pin_memory, pin_memory=pin_memory,
) )
super().__init__(module) super().__init__(module)
...@@ -453,12 +451,13 @@ class GeminiDDP(ModelWrapper): ...@@ -453,12 +451,13 @@ class GeminiDDP(ModelWrapper):
global_shape = get_global_shape(tensor) global_shape = get_global_shape(tensor)
device_mesh = get_device_mesh(tensor) device_mesh = get_device_mesh(tensor)
shard_spec = get_sharding_spec(tensor) shard_spec = get_sharding_spec(tensor)
record_tensor = init_as_dtensor(record_tensor, record_tensor = init_as_dtensor(
device_mesh=device_mesh, record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape
sharding_spec=shard_spec, )
global_shape = global_shape)
elif is_customized_distributed_tensor(tensor): elif is_customized_distributed_tensor(tensor):
init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn) init_tensor_as_customization_distributed(
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
)
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
assert tensor not in chunk_to_save_data assert tensor not in chunk_to_save_data
...@@ -634,7 +633,15 @@ class GeminiDDP(ModelWrapper): ...@@ -634,7 +633,15 @@ class GeminiDDP(ModelWrapper):
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None} local_state = {k: v for k, v in local_name_params if v is not None}
def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None): def load(
param_name,
dest_tensor,
copy_func,
source_device_mesh=None,
source_sharding_spec=None,
shard_fn=None,
gather_fn=None,
):
state_key = prefix + param_name state_key = prefix + param_name
if state_key in state_dict: if state_key in state_dict:
input_param = state_dict[state_key] input_param = state_dict[state_key]
...@@ -642,7 +649,9 @@ class GeminiDDP(ModelWrapper): ...@@ -642,7 +649,9 @@ class GeminiDDP(ModelWrapper):
if source_device_mesh is not None and source_sharding_spec is not None: if source_device_mesh is not None and source_sharding_spec is not None:
input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
elif shard_fn is not None and gather_fn is not None: elif shard_fn is not None and gather_fn is not None:
input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn) input_param = distribute_tensor_with_customization(
input_param, shard_fn=shard_fn, gather_fn=gather_fn
)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
...@@ -687,7 +696,6 @@ class GeminiDDP(ModelWrapper): ...@@ -687,7 +696,6 @@ class GeminiDDP(ModelWrapper):
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
for tensor, tensor_info in chunk.tensors_info.items(): for tensor, tensor_info in chunk.tensors_info.items():
source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
if is_distributed_tensor(tensor): if is_distributed_tensor(tensor):
# shard the input param # shard the input param
...@@ -699,7 +707,15 @@ class GeminiDDP(ModelWrapper): ...@@ -699,7 +707,15 @@ class GeminiDDP(ModelWrapper):
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn) load(
parameter_name,
tensor,
partial(load_parameter, parameter_slice),
source_device_mesh,
source_sharding_spec,
shard_fn,
gather_fn,
)
if chunk.is_gathered: if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk) chunk.cuda_global_chunk.copy_(temp_chunk)
...@@ -799,7 +815,7 @@ class GeminiDDP(ModelWrapper): ...@@ -799,7 +815,7 @@ class GeminiDDP(ModelWrapper):
for buffer in self.module.buffers(): for buffer in self.module.buffers():
if isinstance(buffer, LazyTensor): if isinstance(buffer, LazyTensor):
buffer.materialize() buffer.materialize()
buffer.data = buffer.cuda() buffer.data = buffer.to(get_current_device())
if torch.is_floating_point(buffer): if torch.is_floating_point(buffer):
buffer.data = buffer.to(self.mixed_precision) buffer.data = buffer.to(self.mixed_precision)
......
...@@ -17,9 +17,7 @@ class GeminiManager: ...@@ -17,9 +17,7 @@ class GeminiManager:
https://arxiv.org/abs/2108.05818 https://arxiv.org/abs/2108.05818
Args: Args:
placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. placement_policy (str): Which device to place *held* tensors. It can be 'static' and 'auto'.
If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training. Note that 'auto' policy can only work well when no other processes use CUDA during your training.
chunk_manager (ChunkManager): A ``ChunkManager`` instance. chunk_manager (ChunkManager): A ``ChunkManager`` instance.
...@@ -121,7 +119,7 @@ class GeminiManager: ...@@ -121,7 +119,7 @@ class GeminiManager:
start = time() start = time()
cuda_demand = 0 cuda_demand = 0
for chunk in chunks: for chunk in chunks:
if chunk.device_type == "cuda": if chunk.device_type == "cuda" or chunk.device_type == "npu":
if chunk.is_gathered: if chunk.is_gathered:
pass pass
else: else:
......
...@@ -7,31 +7,29 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union ...@@ -7,31 +7,29 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from packaging.version import Version from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.nn import Parameter from torch.nn import Parameter
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.distributed import ProcessGroup
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.d_tensor import ( from colossalai.tensor.d_tensor import (
distribute_tensor, distribute_tensor,
distribute_tensor_with_customization, distribute_tensor_with_customization,
init_tensor_as_customization_distributed,
get_device_mesh, get_device_mesh,
get_sharding_spec, get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor, is_customized_distributed_tensor,
is_distributed_tensor, is_distributed_tensor,
get_global_shape,
init_as_dtensor
) )
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
...@@ -312,7 +310,7 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -312,7 +310,7 @@ class GeminiOptimizer(OptimizerWrapper):
chunk16 = self.param_to_chunk16[fake_param] chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda": if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
continue continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
...@@ -326,7 +324,7 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -326,7 +324,7 @@ class GeminiOptimizer(OptimizerWrapper):
for fake_param in group["params"]: for fake_param in group["params"]:
chunk16 = self.param_to_chunk16[fake_param] chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda": if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
state = self.optim.state[fake_param] state = self.optim.state[fake_param]
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
...@@ -479,15 +477,19 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -479,15 +477,19 @@ class GeminiOptimizer(OptimizerWrapper):
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
if is_dtensor: if is_dtensor:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
state_tensor = init_as_dtensor(state_tensor, state_tensor = init_as_dtensor(
device_mesh=device_mesh, state_tensor,
sharding_spec=shard_spec, device_mesh=device_mesh,
global_shape = global_shape) sharding_spec=shard_spec,
global_shape=global_shape,
)
elif is_customized_distributed: elif is_customized_distributed:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) init_tensor_as_customization_distributed(
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
collected_states[state_name] = state_tensor.reshape(global_shape) collected_states[state_name] = state_tensor.reshape(global_shape)
return collected_states return collected_states
...@@ -533,13 +535,14 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -533,13 +535,14 @@ class GeminiOptimizer(OptimizerWrapper):
collected_states[state_name] = torch.reshape(state_tensor, param.shape) collected_states[state_name] = torch.reshape(state_tensor, param.shape)
if is_dtensor: if is_dtensor:
state_tensor = state_tensor.to(param.device) state_tensor = state_tensor.to(param.device)
state_tensor = init_as_dtensor(state_tensor, state_tensor = init_as_dtensor(
sharding_spec=shard_spec, state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
device_mesh=device_mesh, )
global_shape=global_shape)
elif is_customized_distributed: elif is_customized_distributed:
state_tensor = state_tensor.to(param.device) state_tensor = state_tensor.to(param.device)
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) init_tensor_as_customization_distributed(
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
return collected_states return collected_states
...@@ -548,7 +551,7 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -548,7 +551,7 @@ class GeminiOptimizer(OptimizerWrapper):
self, self,
param_id: int, param_id: int,
state_names: list, state_names: list,
device: torch.device = torch.device("cuda"), device: torch.device = get_current_device(),
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -705,7 +708,7 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -705,7 +708,7 @@ class GeminiOptimizer(OptimizerWrapper):
ret_val = torch.zeros( ret_val = torch.zeros(
state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
) )
if is_dtensor: if is_dtensor:
value = torch.reshape(value, global_shape) value = torch.reshape(value, global_shape)
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
......
...@@ -12,6 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors ...@@ -12,6 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.optim import Optimizer from torch.optim import Optimizer
import colossalai.utils.device as device_utils
from colossalai.amp.naive_amp.mixed_precision_mixin import ( from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin, BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin, FP16MixedPrecisionMixin,
...@@ -22,7 +23,7 @@ from colossalai.logging import get_dist_logger ...@@ -22,7 +23,7 @@ from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.tensor.moe_tensor.api import is_moe_tensor
# from colossalai.tensor import ColoParameter, ProcessGroup # from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore from .bookkeeping import BucketStore, GradientStore, ParameterStore
...@@ -182,7 +183,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -182,7 +183,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# initialize communication stream for # initialize communication stream for
# communication-computation overlapping # communication-computation overlapping
if self._overlap_communication: if self._overlap_communication:
self._comm_stream = torch.cuda.Stream() self._comm_stream = device_utils.Stream()
# reduction hook is only used if overlapping communication # reduction hook is only used if overlapping communication
# or stage 2 is used # or stage 2 is used
...@@ -216,7 +217,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -216,7 +217,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return len(self._working_param_groups) return len(self._working_param_groups)
def _sanity_checks(self): def _sanity_checks(self):
assert torch.cuda.is_available(), "CUDA is required" assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
for param_group in self.optim.param_groups: for param_group in self.optim.param_groups:
group_params = param_group["params"] group_params = param_group["params"]
for param in group_params: for param in group_params:
...@@ -339,11 +340,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -339,11 +340,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if len(moe_grad_list) > 0: if len(moe_grad_list) > 0:
moe_flat_grads.record_stream(stream) moe_flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing # waiting for ops in the default stream finishing
stream.wait_stream(torch.cuda.current_stream()) stream.wait_stream(device_utils.current_stream())
else: else:
stream = torch.cuda.current_stream() stream = device_utils.current_stream()
with torch.cuda.stream(stream): with device_utils.stream(stream):
group_id = self._bucket_store.current_group_id group_id = self._bucket_store.current_group_id
if self.moe_extra_dp_pg is None: if self.moe_extra_dp_pg is None:
...@@ -485,7 +486,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -485,7 +486,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads # clear reduced grads
if self._overlap_communication: if self._overlap_communication:
torch.cuda.synchronize() device_utils.synchronize()
self.zero_grad() self.zero_grad()
...@@ -504,7 +505,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -504,7 +505,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads # clear reduced grads
if self._overlap_communication: if self._overlap_communication:
torch.cuda.synchronize() device_utils.synchronize()
self.zero_grad() self.zero_grad()
...@@ -620,22 +621,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -620,22 +621,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
release_param_grad(self._master_param_groups_of_current_rank[group_id]) release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update working partition updated by the current rank # update working partition updated by the current rank
device = get_current_device()
for group_id in range(self.num_param_groups): for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"] master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param): for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx] working_param = real_working_params[group_id][idx]
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [ all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self.moe_extra_dp_pg_size) for _ in range(self.moe_extra_dp_pg_size)
] ]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg) dist.all_gather(
all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
)
else: else:
all_splited_param = [ all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._world_size) for _ in range(self._world_size)
] ]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg) dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
...@@ -657,7 +661,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -657,7 +661,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
norm_type = float(norm_type) norm_type = float(norm_type)
if norm_type == inf: if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients) total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item() total_norm = total_norm_cuda.item()
...@@ -668,7 +672,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -668,7 +672,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated += grad_norm_exponentiated
# Sum across all model parallel GPUs. # Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
)
torch.distributed.all_reduce( torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
) )
...@@ -759,6 +765,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -759,6 +765,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
Dict: the pytorch form state_dict Dict: the pytorch form state_dict
""" """
zero_state = dict() zero_state = dict()
device = get_current_device()
for param, state in self.optim.state.items(): for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state) zero_state[param] = copy.deepcopy(state)
for k, v in state.items(): for k, v in state.items():
...@@ -766,14 +773,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -766,14 +773,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = self._param_store.master_to_working_param[id(param)] working_param = self._param_store.master_to_working_param[id(param)]
if self.moe_extra_dp_pg is not None and is_moe_tensor(v): if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
gather_tensor = [ gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
] ]
dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg) dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
else: else:
gather_tensor = [ gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
] ]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
param_state = ( param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
) )
...@@ -820,6 +827,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -820,6 +827,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block = dict() ret_block = dict()
ret_block_size = 0 ret_block_size = 0
device = get_current_device()
local_states = self.optim.state_dict()["state"] local_states = self.optim.state_dict()["state"]
for param_idx, states in local_states.items(): for param_idx, states in local_states.items():
current_block_size = 0 current_block_size = 0
...@@ -836,14 +844,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ...@@ -836,14 +844,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
if self.moe_extra_dp_pg is not None and is_moe_tensor(v): if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
state_tensor = [ state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
] ]
dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg) dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
else: else:
state_tensor = [ state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
] ]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
state_tensor = ( state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
) )
......
...@@ -13,6 +13,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig ...@@ -13,6 +13,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers.models.llama.modeling_llama import LlamaForCausalLM
import colossalai import colossalai
import colossalai.utils.device as device_utils
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
...@@ -194,7 +195,7 @@ def main(): ...@@ -194,7 +195,7 @@ def main():
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float) torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master( coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
) )
...@@ -220,7 +221,7 @@ def main(): ...@@ -220,7 +221,7 @@ def main():
performance_evaluator.on_step_end(**batch) performance_evaluator.on_step_end(**batch)
performance_evaluator.on_fit_end() performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -5,7 +5,9 @@ import torch ...@@ -5,7 +5,9 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch import Tensor from torch import Tensor
import colossalai.utils.device as device_utils
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils.device import get_current_device
def divide(x: float, y: float) -> float: def divide(x: float, y: float) -> float:
...@@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float: ...@@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float: def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1: if world_size == 1:
return x return x
tensor = torch.tensor([x], device=torch.cuda.current_device()) tensor = torch.tensor([x], device=get_current_device())
dist.all_reduce(tensor) dist.all_reduce(tensor)
tensor = tensor / world_size tensor = tensor / world_size
return tensor.item() return tensor.item()
...@@ -84,13 +86,13 @@ class PerformanceEvaluator: ...@@ -84,13 +86,13 @@ class PerformanceEvaluator:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable: if self.disable:
return return
torch.cuda.synchronize() device_utils.synchronize()
self.timer.start() self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None: def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable: if self.disable:
return return
torch.cuda.synchronize() device_utils.synchronize()
self.timer.end() self.timer.end()
batch_size, seq_len = input_ids.shape batch_size, seq_len = input_ids.shape
......
from .arm_cpu_adam import ArmCPUAdamBuilder
from .cpu_adam import CPUAdamBuilder from .cpu_adam import CPUAdamBuilder
from .fused_optim import FusedOptimBuilder from .fused_optim import FusedOptimBuilder
from .layernorm import LayerNormBuilder from .layernorm import LayerNormBuilder
...@@ -29,4 +30,5 @@ __all__ = [ ...@@ -29,4 +30,5 @@ __all__ = [
"MultiTensorLambBuilder", "MultiTensorLambBuilder",
"MultiTensorScaleBuilder", "MultiTensorScaleBuilder",
"MultiTensorL2NormBuilder", "MultiTensorL2NormBuilder",
"ArmCPUAdamBuilder",
] ]
from .builder import Builder
class ArmCPUAdamBuilder(Builder):
    """Builder for the ARM CPU Adam kernel (``cpu_adam_arm.cpp``).

    The kernel is a host-only C++ extension: it declares no NVCC flags and is
    compiled with OpenMP enabled.
    """

    NAME = "arm_cpu_adam"
    PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
    # Must be exactly "cpp": ``Builder.builder()`` returns a ``CppExtension``
    # only for this value and falls through to ``CUDAExtension`` for anything
    # else (the previous value "cpu" therefore selected the CUDA path).
    # Any value other than "cuda" also skips the CUDA runtime-environment
    # check performed before a JIT build.
    ext_type = "cpp"

    def __init__(self):
        super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    # necessary 4 functions
    def sources_files(self):
        """Return the absolute paths of the C++ sources to compile."""
        return [self.csrc_abs_path("cpu_adam_arm.cpp")]

    def include_dirs(self):
        """Return the absolute paths of the header search directories."""
        return [self.csrc_abs_path("includes")]

    def cxx_flags(self):
        """Return the host compiler flags for the extension.

        Only one ``-std=`` flag is emitted. The original list passed both
        ``-std=c++14`` and ``-std=c++17``; with GCC/Clang the last one wins,
        so C++17 was already the effective standard.
        """
        extra_cxx_flags = [
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self):
        """Return no NVCC flags: this extension has no CUDA sources."""
        return []
...@@ -7,7 +7,7 @@ import os ...@@ -7,7 +7,7 @@ import os
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional, Union
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
...@@ -21,6 +21,8 @@ class Builder(ABC): ...@@ -21,6 +21,8 @@ class Builder(ABC):
prebuilt_import_path (str): the path where the extension is installed during pip install prebuilt_import_path (str): the path where the extension is installed during pip install
""" """
ext_type: str = "cuda"
def __init__(self, name: str, prebuilt_import_path: str): def __init__(self, name: str, prebuilt_import_path: str):
self.name = name self.name = name
self.prebuilt_import_path = prebuilt_import_path self.prebuilt_import_path = prebuilt_import_path
...@@ -165,7 +167,8 @@ class Builder(ABC): ...@@ -165,7 +167,8 @@ class Builder(ABC):
) )
except ImportError: except ImportError:
# check environment # check environment
self.check_runtime_build_environment() if self.ext_type == "cuda":
self.check_runtime_build_environment()
# time the kernel compilation # time the kernel compilation
start_build = time.time() start_build = time.time()
...@@ -208,11 +211,19 @@ class Builder(ABC): ...@@ -208,11 +211,19 @@ class Builder(ABC):
return op_module return op_module
def builder(self) -> "CUDAExtension": def builder(self) -> Union["CUDAExtension", "CppExtension"]:
""" """
get a CUDAExtension instance used for setup.py get a CUDAExtension instance used for setup.py
""" """
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CppExtension, CUDAExtension
if self.ext_type == "cpp":
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
return CUDAExtension( return CUDAExtension(
name=self.prebuilt_import_path, name=self.prebuilt_import_path,
......
...@@ -2,11 +2,14 @@ from typing import Optional ...@@ -2,11 +2,14 @@ from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.optim import Adam
import colossalai import colossalai
import colossalai.utils.device as device_utils
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
# from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
...@@ -19,16 +22,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"] ...@@ -19,16 +22,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
device = device_utils.get_current_device()
try: try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
model = model_fn() model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = Adam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean() criterion = lambda x: x.mean()
data = data_gen_fn() data = data_gen_fn()
data = { data = {
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
} }
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
...@@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): ...@@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
continue continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache() device_utils.empty_cache()
if err is None: if err is None:
passed_models.append(name) passed_models.append(name)
...@@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): ...@@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_low_level_zero_plugin(early_stop: bool = True): def test_low_level_zero_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop) spawn(run_dist, 2, early_stop=early_stop)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment