Unverified Commit e5ce4c8e authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
parent 8d56c9c3
......@@ -10,7 +10,7 @@ from torch.utils._pytree import tree_map
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule
......
......@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
......
......@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from ._utils import (
detach,
......
......@@ -2,16 +2,19 @@
# -*- encoding: utf-8 -*-
import warnings
from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.lazy import LazyInitContext
from ._operation import hook_paramter_in_backward
from ._operation import hook_paramter_in_backward
from .utils import SeqParallelUtils
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
EnableFastLayerNorm = True
except ImportError:
EnableFastLayerNorm = False
......@@ -19,10 +22,27 @@ except ImportError:
try:
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
class FusedLayerNormWithHook(ApexFusedLayerNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)
def forward(self, input):
output = super().forward(input)
output = hook_paramter_in_backward(output, self.weight, self.bias)
return output
class FusedRMSNormWithHook(ApexFusedRMSNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)
def forward(self, input):
output = super().forward(input)
output = hook_paramter_in_backward(output, self.weight)
return output
except ImportError:
warnings.warn(
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
)
warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
......@@ -52,6 +72,7 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
]
if EnableFastLayerNorm:
class FastLayerNormWithHook(FastLayerNorm):
def __init__(self, hidden_size, eps=0.00001):
super().__init__(hidden_size, eps)
......@@ -61,24 +82,6 @@ if EnableFastLayerNorm:
output = hook_paramter_in_backward(output, self.weight, self.bias)
return output
class FusedLayerNormWithHook(ApexFusedLayerNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)
def forward(self, input):
output = super().forward(input)
output = hook_paramter_in_backward(output, self.weight, self.bias)
return output
class FusedRMSNormWithHook(ApexFusedRMSNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)
def forward(self, input):
output = super().forward(input)
output = hook_paramter_in_backward(output, self.weight)
return output
class BaseLayerNorm(ABC):
@abstractmethod
......@@ -244,6 +247,7 @@ class FusedRMSNorm(BaseLayerNorm):
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"FusedRMSNorm is not implemented as a physical class. "
......@@ -264,7 +268,7 @@ class FusedRMSNorm(BaseLayerNorm):
nn.Module: FusedRMSNorm module.
"""
try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
pass
except ImportError:
raise ImportError(
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
......@@ -282,7 +286,9 @@ class FusedRMSNorm(BaseLayerNorm):
eps = module.eps
elementwise_affine = module.elementwise_affine
rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
)
rmsnorm.weight = module.weight
......
......@@ -7,7 +7,7 @@ from .common import (
is_ddp_ignored,
set_seed,
)
from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize
from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
from .multi_tensor_apply import multi_tensor_applier
from .tensor_detector import TensorDetector
from .timer import MultiTimer, Timer
......@@ -29,4 +29,5 @@ __all__ = [
"set_seed",
"is_ddp_ignored",
"set_device",
"IS_NPU_AVAILABLE",
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
import torch
import torch.distributed as dist
def set_to_cuda(models):
"""Send model to gpu.
:param models: nn.module or a list of module
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(get_current_device())
else:
return models.to(get_current_device())
def get_current_device() -> torch.device:
"""
Returns currently selected device (gpu/cpu).
If cuda available, return gpu, otherwise return cpu.
"""
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
else:
return torch.device("cpu")
def synchronize():
"""Similar to cuda.synchronize().
Waits for all kernels in all streams on a CUDA device to complete.
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
def empty_cache():
"""Similar to cuda.empty_cache()
Releases all unoccupied cached memory currently held by the caching allocator.
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
def set_device(index: Optional[int] = None) -> None:
if index is None:
index = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(index)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
IS_NPU_AVAILABLE: bool = False
try:
import torch_npu # noqa
IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
pass
def set_to_cuda(models):
"""Send model to gpu.
:param models: nn.module or a list of module
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(get_current_device())
else:
return models.to(get_current_device())
def get_current_device() -> torch.device:
"""
Returns currently selected device (gpu/cpu).
If cuda available, return gpu, otherwise return cpu.
"""
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif IS_NPU_AVAILABLE:
return torch.device(f"npu:{torch.npu.current_device()}")
else:
return torch.device("cpu")
def _dispatch_device_func(fn_name: str, *args, **kwargs):
if torch.cuda.is_available():
return getattr(torch.cuda, fn_name)(*args, **kwargs)
elif IS_NPU_AVAILABLE:
return getattr(torch.npu, fn_name)(*args, **kwargs)
else:
raise RuntimeError("No device available")
# device semantics
def can_device_access_peer(device, peer_device) -> bool:
return _dispatch_device_func("can_device_access_peer", device, peer_device)
def current_device() -> int:
return _dispatch_device_func("current_device")
def current_stream(device=None):
return _dispatch_device_func("current_stream", device)
def default_stream(device=None):
return _dispatch_device_func("default_stream", device)
def device_count() -> int:
return _dispatch_device_func("device_count")
def get_device_capability(device=None) -> Tuple[int, int]:
return _dispatch_device_func("get_device_capability", device)
def get_device_name(device=None) -> str:
return _dispatch_device_func("get_device_name", device)
def get_device_properties(device):
return _dispatch_device_func("get_device_properties", device)
def set_device(index: Optional[int] = None) -> None:
if index is None:
index = dist.get_rank() % device_count()
_dispatch_device_func("set_device", index)
def set_stream(stream_):
return _dispatch_device_func("set_stream", stream_)
def stream(stream_):
return _dispatch_device_func("stream", stream_)
def synchronize():
return _dispatch_device_func("synchronize")
def utilization(device=None) -> int:
return _dispatch_device_func("utilization", device)
# random number generator
def get_rng_state(device="cuda") -> torch.Tensor:
return _dispatch_device_func("get_rng_state", device)
def get_rng_state_all() -> List[torch.Tensor]:
return _dispatch_device_func("get_rng_state_all")
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
return _dispatch_device_func("set_rng_state", new_state, device)
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
return _dispatch_device_func("set_rng_state_all", new_states)
def manual_seed(seed: int) -> None:
return _dispatch_device_func("manual_seed", seed)
def manual_seed_all(seed: int) -> None:
return _dispatch_device_func("manual_seed_all", seed)
def seed() -> None:
return _dispatch_device_func("seed")
def seed_all() -> None:
return _dispatch_device_func("seed_all")
def initial_seed() -> int:
return _dispatch_device_func("initial_seed")
# streams and events
def Stream(device=None, priority=0, **kwargs):
return _dispatch_device_func("Stream", device, priority, **kwargs)
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
# memory management
def empty_cache() -> None:
return _dispatch_device_func("empty_cache")
def memory_stats(device=None) -> Dict[str, Any]:
return _dispatch_device_func("memory_stats", device)
def memory_summary(device=None, abbreviated=False) -> str:
return _dispatch_device_func("memory_summary", device, abbreviated)
def memory_snapshot():
return _dispatch_device_func("memory_snapshot")
def memory_allocated(device=None) -> int:
return _dispatch_device_func("memory_allocated", device)
def max_memory_allocated(device=None) -> int:
return _dispatch_device_func("max_memory_allocated", device)
def reset_max_memory_allocated(device=None) -> None:
return _dispatch_device_func("reset_max_memory_allocated", device)
def memory_reserved(device=None) -> int:
return _dispatch_device_func("memory_reserved", device)
def max_memory_reserved(device=None) -> int:
return _dispatch_device_func("max_memory_reserved", device)
def set_per_process_memory_fraction(fraction: float, device=None) -> None:
return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
def reset_peak_memory_stats(device=None) -> None:
return _dispatch_device_func("reset_peak_memory_stats", device)
......@@ -3,7 +3,7 @@
import time
from typing import Tuple
from .cuda import synchronize
from .device import synchronize
class Timer:
......
......@@ -7,6 +7,7 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.utils.device import IS_NPU_AVAILABLE
class TensorState(Enum):
......@@ -172,7 +173,7 @@ class Chunk:
if self.chunk_temp is not None:
# this chunk is not closed
if self.chunk_temp.device.type == "cuda":
if self.chunk_temp.device.type == "cuda" or self.chunk_temp.device.type == "npu":
cuda_memory += self.chunk_mem
else:
cpu_memory += self.chunk_mem
......@@ -191,10 +192,8 @@ class Chunk:
if self.chunk_temp is not None:
return self.chunk_temp.device.type
else:
if self.is_gathered:
return "cuda"
elif self.cuda_shard is not None:
return "cuda"
if self.is_gathered or self.cuda_shard is not None:
return "npu" if IS_NPU_AVAILABLE else "cuda"
else:
return "cpu"
......@@ -329,12 +328,12 @@ class Chunk:
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
assert device.type == "cuda", "each chunk should first be moved to CUDA"
assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
self.__paired_shard_move()
self.optim_sync_flag = True
return
if device.type == "cuda":
if device.type == "cuda" or device.type == "npu":
assert device == get_current_device(), "can't move chunk to another device"
if self.cuda_shard:
......@@ -484,7 +483,7 @@ class Chunk:
assert friend_chunk.is_gathered is True
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
self.optim_sync_flag = True
elif friend_chunk.device_type == "cuda" and self.device_type == "cuda":
elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"):
self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True
self.cpu_vis_flag = False
......
......@@ -206,7 +206,10 @@ class ChunkManager:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
"""
assert tensor not in self.tensor_chunk_map
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
device_type = tensor.device.type
if device_type == "npu":
device_type = "cuda"
self.total_mem[device_type] += tensor.numel() * tensor.element_size()
def __repr__(self) -> str:
msg = [
......
......@@ -10,32 +10,30 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
from colossalai.checkpoint_io.utils import gather_distributed_param
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
from .gemini_mgr import GeminiManager
from .memory_tracer import MemStats, OrderedParamGenerator
from .utils import get_temp_total_chunk_on_cuda
from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
init_tensor_as_customization_distributed,
get_device_mesh,
get_global_shape,
get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor,
is_distributed_tensor,
get_global_shape,
init_as_dtensor
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
from .gemini_mgr import GeminiManager
from .memory_tracer import MemStats, OrderedParamGenerator
from .utils import get_temp_total_chunk_on_cuda
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
......@@ -162,7 +160,7 @@ class GeminiDDP(ModelWrapper):
self._init_chunks(
param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != "cuda",
cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0),
pin_memory=pin_memory,
)
super().__init__(module)
......@@ -453,12 +451,13 @@ class GeminiDDP(ModelWrapper):
global_shape = get_global_shape(tensor)
device_mesh = get_device_mesh(tensor)
shard_spec = get_sharding_spec(tensor)
record_tensor = init_as_dtensor(record_tensor,
device_mesh=device_mesh,
sharding_spec=shard_spec,
global_shape = global_shape)
record_tensor = init_as_dtensor(
record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape
)
elif is_customized_distributed_tensor(tensor):
init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn)
init_tensor_as_customization_distributed(
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
)
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
assert tensor not in chunk_to_save_data
......@@ -634,7 +633,15 @@ class GeminiDDP(ModelWrapper):
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None):
def load(
param_name,
dest_tensor,
copy_func,
source_device_mesh=None,
source_sharding_spec=None,
shard_fn=None,
gather_fn=None,
):
state_key = prefix + param_name
if state_key in state_dict:
input_param = state_dict[state_key]
......@@ -642,7 +649,9 @@ class GeminiDDP(ModelWrapper):
if source_device_mesh is not None and source_sharding_spec is not None:
input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
elif shard_fn is not None and gather_fn is not None:
input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn)
input_param = distribute_tensor_with_customization(
input_param, shard_fn=shard_fn, gather_fn=gather_fn
)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
......@@ -687,7 +696,6 @@ class GeminiDDP(ModelWrapper):
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
for tensor, tensor_info in chunk.tensors_info.items():
source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
if is_distributed_tensor(tensor):
# shard the input param
......@@ -699,7 +707,15 @@ class GeminiDDP(ModelWrapper):
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn)
load(
parameter_name,
tensor,
partial(load_parameter, parameter_slice),
source_device_mesh,
source_sharding_spec,
shard_fn,
gather_fn,
)
if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk)
......@@ -799,7 +815,7 @@ class GeminiDDP(ModelWrapper):
for buffer in self.module.buffers():
if isinstance(buffer, LazyTensor):
buffer.materialize()
buffer.data = buffer.cuda()
buffer.data = buffer.to(get_current_device())
if torch.is_floating_point(buffer):
buffer.data = buffer.to(self.mixed_precision)
......
......@@ -17,9 +17,7 @@ class GeminiManager:
https://arxiv.org/abs/2108.05818
Args:
placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'.
If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
placement_policy (str): Which device to place *held* tensors. It can be 'static' and 'auto'.
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
......@@ -121,7 +119,7 @@ class GeminiManager:
start = time()
cuda_demand = 0
for chunk in chunks:
if chunk.device_type == "cuda":
if chunk.device_type == "cuda" or chunk.device_type == "npu":
if chunk.is_gathered:
pass
else:
......
......@@ -7,31 +7,29 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch
import torch.distributed as dist
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.distributed import ProcessGroup
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import StateDictSharder
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
init_tensor_as_customization_distributed,
get_device_mesh,
get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor,
is_distributed_tensor,
get_global_shape,
init_as_dtensor
)
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
......@@ -312,7 +310,7 @@ class GeminiOptimizer(OptimizerWrapper):
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda":
if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
......@@ -326,7 +324,7 @@ class GeminiOptimizer(OptimizerWrapper):
for fake_param in group["params"]:
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda":
if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
state = self.optim.state[fake_param]
for k, v in state.items():
if isinstance(v, torch.Tensor):
......@@ -479,13 +477,17 @@ class GeminiOptimizer(OptimizerWrapper):
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
if is_dtensor:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
state_tensor = init_as_dtensor(state_tensor,
state_tensor = init_as_dtensor(
state_tensor,
device_mesh=device_mesh,
sharding_spec=shard_spec,
global_shape = global_shape)
global_shape=global_shape,
)
elif is_customized_distributed:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
init_tensor_as_customization_distributed(
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
collected_states[state_name] = state_tensor.reshape(global_shape)
......@@ -533,13 +535,14 @@ class GeminiOptimizer(OptimizerWrapper):
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
if is_dtensor:
state_tensor = state_tensor.to(param.device)
state_tensor = init_as_dtensor(state_tensor,
sharding_spec=shard_spec,
device_mesh=device_mesh,
global_shape=global_shape)
state_tensor = init_as_dtensor(
state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
)
elif is_customized_distributed:
state_tensor = state_tensor.to(param.device)
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
init_tensor_as_customization_distributed(
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
return collected_states
......@@ -548,7 +551,7 @@ class GeminiOptimizer(OptimizerWrapper):
self,
param_id: int,
state_names: list,
device: torch.device = torch.device("cuda"),
device: torch.device = get_current_device(),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
......
......@@ -12,6 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
import colossalai.utils.device as device_utils
from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin,
......@@ -22,7 +23,7 @@ from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor
# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore
......@@ -182,7 +183,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# intialize communication stream for
# communication-compuation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
self._comm_stream = device_utils.Stream()
# reduction hook is only used if overlapping communication
# or stage 2 is used
......@@ -216,7 +217,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return len(self._working_param_groups)
def _sanity_checks(self):
assert torch.cuda.is_available(), "CUDA is required"
assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
for param_group in self.optim.param_groups:
group_params = param_group["params"]
for param in group_params:
......@@ -339,11 +340,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if len(moe_grad_list) > 0:
moe_flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing
stream.wait_stream(torch.cuda.current_stream())
stream.wait_stream(device_utils.current_stream())
else:
stream = torch.cuda.current_stream()
stream = device_utils.current_stream()
with torch.cuda.stream(stream):
with device_utils.stream(stream):
group_id = self._bucket_store.current_group_id
if self.moe_extra_dp_pg is None:
......@@ -485,7 +486,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
device_utils.synchronize()
self.zero_grad()
......@@ -504,7 +505,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
device_utils.synchronize()
self.zero_grad()
......@@ -620,22 +621,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update working partition updated by the current rank
device = get_current_device()
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg)
dist.all_gather(
all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
)
else:
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
......@@ -657,7 +661,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
......@@ -668,7 +672,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
total_norm_exponentiated += grad_norm_exponentiated
# Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
)
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
)
......@@ -759,6 +765,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
Dict: the pytorch form state_dict
"""
zero_state = dict()
device = get_current_device()
for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
......@@ -766,14 +773,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = self._param_store.master_to_working_param[id(param)]
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg)
dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
else:
gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
......@@ -820,6 +827,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block = dict()
ret_block_size = 0
device = get_current_device()
local_states = self.optim.state_dict()["state"]
for param_idx, states in local_states.items():
current_block_size = 0
......@@ -836,14 +844,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if isinstance(v, torch.Tensor) and k != "step":
if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg)
dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
else:
state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
......
......@@ -13,6 +13,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import colossalai
import colossalai.utils.device as device_utils
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator
......@@ -194,7 +195,7 @@ def main():
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
)
......@@ -220,7 +221,7 @@ def main():
performance_evaluator.on_step_end(**batch)
performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
......
......@@ -5,7 +5,9 @@ import torch
import torch.distributed as dist
from torch import Tensor
import colossalai.utils.device as device_utils
from colossalai.cluster import DistCoordinator
from colossalai.utils.device import get_current_device
def divide(x: float, y: float) -> float:
......@@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=torch.cuda.current_device())
tensor = torch.tensor([x], device=get_current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
......@@ -84,13 +86,13 @@ class PerformanceEvaluator:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
torch.cuda.synchronize()
device_utils.synchronize()
self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
torch.cuda.synchronize()
device_utils.synchronize()
self.timer.end()
batch_size, seq_len = input_ids.shape
......
from .arm_cpu_adam import ArmCPUAdamBuilder
from .cpu_adam import CPUAdamBuilder
from .fused_optim import FusedOptimBuilder
from .layernorm import LayerNormBuilder
......@@ -29,4 +30,5 @@ __all__ = [
"MultiTensorLambBuilder",
"MultiTensorScaleBuilder",
"MultiTensorL2NormBuilder",
"ArmCPUAdamBuilder",
]
from .builder import Builder
class ArmCPUAdamBuilder(Builder):
NAME = "arm_cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
ext_type = "cpu"
def __init__(self):
super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam_arm.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes")]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-g",
"-Wno-reorder",
"-fopenmp",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
return []
......@@ -7,7 +7,7 @@ import os
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Union
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
......@@ -21,6 +21,8 @@ class Builder(ABC):
prebuilt_import_path (str): the path where the extension is installed during pip install
"""
ext_type: str = "cuda"
def __init__(self, name: str, prebuilt_import_path: str):
self.name = name
self.prebuilt_import_path = prebuilt_import_path
......@@ -165,6 +167,7 @@ class Builder(ABC):
)
except ImportError:
# check environment
if self.ext_type == "cuda":
self.check_runtime_build_environment()
# time the kernel compilation
......@@ -208,11 +211,19 @@ class Builder(ABC):
return op_module
def builder(self) -> "CUDAExtension":
def builder(self) -> Union["CUDAExtension", "CppExtension"]:
"""
get a CUDAExtension instance used for setup.py
"""
from torch.utils.cpp_extension import CUDAExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension
if self.ext_type == "cpp":
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
return CUDAExtension(
name=self.prebuilt_import_path,
......
......@@ -2,11 +2,14 @@ from typing import Optional
import torch
import torch.distributed as dist
from torch.optim import Adam
import colossalai
import colossalai.utils.device as device_utils
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
# from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
......@@ -19,16 +22,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
device = device_utils.get_current_device()
try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
......@@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
device_utils.empty_cache()
if err is None:
passed_models.append(name)
......@@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use()
def test_low_level_zero_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)
spawn(run_dist, 2, early_stop=early_stop)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment