Unverified Commit 8823cc48 authored by Frank Lee, committed by GitHub

Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu
parents bce9499e 73f4dc57
......@@ -77,9 +77,9 @@ class FusedLAMB(torch.optim.Optimizer):
)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
-fused_optim = FusedOptimBuilder().load()
+fused_optim = FusedOptimizerLoader().load()
self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
# Skip buffer
......
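Context for the rename: `FusedOptimizerLoader` keeps the old builder's contract of loading the fused kernels lazily, only when multi-tensor application is actually available. A minimal sketch of that guarded pattern (the helper name `load_fused_optim_or_none` is ours, for illustration):

```python
from colossalai.utils import multi_tensor_applier


def load_fused_optim_or_none():
    """Load the fused optimizer kernels lazily, or return None on
    platforms (e.g. CPU-only) where they are unavailable."""
    if not multi_tensor_applier.available:
        return None
    # Import inside the function so CPU-only installs never touch the loader.
    from colossalai.kernel.kernel_loader import FusedOptimizerLoader
    return FusedOptimizerLoader().load()
```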
......@@ -72,9 +72,9 @@ class FusedSGD(Optimizer):
self.wd_after_momentum = wd_after_momentum
if multi_tensor_applier.available:
-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
-fused_optim = FusedOptimBuilder().load()
+fused_optim = FusedOptimizerLoader().load()
# Skip buffer
self._dummy_overflow_buf = torch.tensor(
......
......@@ -2,7 +2,7 @@ from typing import Any, Optional
import torch
-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.utils import multi_tensor_applier
from .cpu_adam import CPUAdam
......@@ -85,7 +85,7 @@ class HybridAdam(CPUAdam):
nvme_offload_dir,
)
if torch.cuda.is_available():
-fused_optim = FusedOptimBuilder().load()
+fused_optim = FusedOptimizerLoader().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
......
......@@ -7,10 +7,10 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
+from colossalai.accelerator import get_accelerator
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule
......@@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule):
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
-return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def _prepare_inputs_for_interval_stage(self):
"""
......
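The `tree_map(partial(to_device, ...))` idiom above moves every tensor in an arbitrarily nested micro-batch to the accelerator's current device in one pass. A self-contained sketch with a simplified `to_device` (the real helper lives in `._utils`):

```python
from functools import partial

import torch
from torch.utils._pytree import tree_map


def to_device(x, device):
    # Move tensor leaves; pass non-tensor leaves (ints, strings, ...) through.
    return x.to(device) if isinstance(x, torch.Tensor) else x


micro_batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
micro_batch = tree_map(partial(to_device, device=device), micro_batch)
```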
......@@ -6,10 +6,11 @@ import torch.cuda
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
+from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
+from colossalai.utils import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
......@@ -100,7 +101,7 @@ class InterleavedSchedule(PipelineSchedule):
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
-return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
......
......@@ -6,10 +6,11 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
+from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
+from colossalai.utils import get_current_device
from ._utils import (
detach,
......@@ -110,7 +111,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
-return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def recv_forward(self, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
......@@ -317,7 +318,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
accum_loss = None
if return_loss and self.stage_manager.is_last_stage():
-accum_loss = torch.scalar_tensor(0, device=get_current_device())
+accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
for _ in range(self.num_microbatches):
......
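For readers following the `accum_loss` change: the last pipeline stage allocates a scalar tensor on the accelerator and folds each microbatch's loss into it. A toy sketch of the accumulation (whether the schedule averages or sums per microbatch is not shown in this hunk; the sketch averages):

```python
import torch

num_microbatches = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
accum_loss = torch.scalar_tensor(0, device=device)

for _ in range(num_microbatches):
    micro_loss = torch.rand((), device=device)  # stand-in for the criterion output
    accum_loss.add_(micro_loss / num_microbatches)
```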
......@@ -6,7 +6,8 @@ import torch.distributed as dist
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size
-from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
+from colossalai.accelerator import get_accelerator
class SeqParallelUtils:
......@@ -109,10 +110,10 @@ class Randomizer:
# 1. get the current rng state
# 2. set the seed and store the rng state
# 3. recover the original rng state
-device_original_rng_state = get_rng_state()
-manual_seed(seed)
-self.device_rng_state = get_rng_state()
-set_rng_state(device_original_rng_state)
+device_original_rng_state = get_accelerator().get_rng_state()
+get_accelerator().manual_seed(seed)
+self.device_rng_state = get_accelerator().get_rng_state()
+get_accelerator().set_rng_state(device_original_rng_state)
# do the same for cpu rng state
cpu_original_rng_state = torch.get_rng_state()
......@@ -121,10 +122,10 @@ class Randomizer:
torch.set_rng_state(cpu_original_rng_state)
def _set_device_rng_state(self, rng_state):
-set_rng_state(rng_state)
+get_accelerator().set_rng_state(rng_state)
def _get_device_rng_state(self):
-current_state = get_rng_state()
+current_state = get_accelerator().get_rng_state()
return current_state
def _set_cpu_rng_state(self, rng_state):
......@@ -209,7 +210,7 @@ class Randomizer:
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
-index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
+index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
......@@ -231,7 +232,7 @@ class Randomizer:
if dist.is_initialized():
# convert the index to tensor
-index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
+index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
......
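The Randomizer hunks above all follow one save/seed/restore dance so that capturing a seeded RNG state never disturbs the caller's RNG stream. A CPU-only sketch of the same three steps (the diff performs them through `get_accelerator()` for the device RNG):

```python
import torch


def capture_seeded_state(seed: int) -> torch.Tensor:
    original = torch.get_rng_state()    # 1. get the current rng state
    torch.manual_seed(seed)             # 2. set the seed and store the rng state
    seeded_state = torch.get_rng_state()
    torch.set_rng_state(original)       # 3. recover the original rng state
    return seeded_state
```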
......@@ -62,7 +62,7 @@ def forward_fn():
def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
-from colossalai.kernel.cuda_native import ColoAttention
+from colossalai.nn.layer.colo_attention import ColoAttention
def forward(
self: Blip2Attention,
......
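All of the `ColoAttention` hunks in this PR are the same one-line move from `colossalai.kernel.cuda_native` to `colossalai.nn.layer.colo_attention`. Downstream code that must run against releases on both sides of the move could use an import shim like this (a hedged sketch; whether the old path remains importable depends on the installed release):

```python
try:
    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
except ImportError:  # releases predating this PR
    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
```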
......@@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
def get_flash_core_attention_forward():
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention
......
......@@ -719,7 +719,7 @@ class GPT2PipelineForwards:
def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size):
"""
......
......@@ -530,7 +530,7 @@ class GPTJPipelineForwards:
def get_gptj_flash_attention_forward():
from transformers.models.gptj.modeling_gptj import GPTJAttention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_attention_heads, attn_head_size, rotary):
"""
......
......@@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
-import torch.distributed as dist
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
......@@ -15,14 +14,17 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False
class LlamaPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
......@@ -203,7 +205,7 @@ class LlamaPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
-shard_config: ShardConfig = None
+shard_config: ShardConfig = None,
):
r"""
Args:
......@@ -279,12 +281,13 @@ class LlamaPipelineForwards:
if shard_config.enable_tensor_parallelism:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
-loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group)
+loss = cross_entropy_1d(
+    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
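On the shift in `shift_logits`/`shift_labels`: a causal LM predicts token t+1 from positions up to t, so logits drop the last position and labels drop the first before both are flattened for the loss. A single-device sketch (the tensor-parallel branch above swaps `F.cross_entropy` for `cross_entropy_1d` with a process group):

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab_size = 2, 8, 100
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = F.cross_entropy(shift_logits, shift_labels)
```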
......@@ -417,7 +420,7 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
llama_version = 2
try:
......@@ -480,7 +483,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
-query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+query_states,
+key_states,
+value_states,
+attn_mask=flash_attention_mask,
+attn_mask_type=attn_mask_type,
+origin_attn_mask=attention_mask,
)
attn_output = self.o_proj(attn_output)
......@@ -573,12 +581,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
if shard_config.enable_tensor_parallelism:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
-loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group)
+loss = cross_entropy_1d(
+    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
......@@ -590,4 +599,5 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return forward
......@@ -6,7 +6,7 @@ import torch
def get_mistral_flash_attention_forward():
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def forward(
self: MistralAttention,
......
......@@ -514,7 +514,7 @@ class OPTPipelineForwards:
def get_opt_flash_attention_forward():
from transformers.models.opt.modeling_opt import OPTAttention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def forward(
self: OPTAttention,
......
......@@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention
-from colossalai.kernel.cuda_native import ColoAttention
+from colossalai.nn.layer.colo_attention import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
......
......@@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
......
......@@ -9,7 +9,8 @@ from typing import Any, Callable, List
import torch
import torch.multiprocessing as mp
from packaging import version
-from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
+from colossalai.accelerator import get_accelerator
def parameterize(argument: str, values: List[Any]) -> Callable:
......@@ -199,7 +200,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
def _wrap_func(f):
def _execute_by_gpu_num(*args, **kwargs):
-num_avail_gpu = device_count()
+num_avail_gpu = get_accelerator().device_count()
if num_avail_gpu >= min_gpus:
f(*args, **kwargs)
......@@ -263,11 +264,11 @@ def clear_cache_before_run():
def _wrap_func(f):
def _clear_cache(*args, **kwargs):
-empty_cache()
-reset_peak_memory_stats()
-reset_max_memory_allocated()
-reset_max_memory_cached()
-synchronize()
+get_accelerator().empty_cache()
+get_accelerator().reset_peak_memory_stats()
+get_accelerator().reset_max_memory_allocated()
+get_accelerator().reset_max_memory_cached()
+get_accelerator().synchronize()
gc.collect()
f(*args, **kwargs)
......
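Assembled into a runnable whole, the decorator this hunk migrates looks roughly like the following (the `@wraps` and the explicit `return` are our additions for idiomatic completeness, not part of the diff):

```python
import gc
from functools import wraps

from colossalai.accelerator import get_accelerator


def clear_cache_before_run():
    """Reset accelerator memory bookkeeping and collect garbage
    before the wrapped function runs."""

    def _wrap_func(f):
        @wraps(f)
        def _clear_cache(*args, **kwargs):
            accelerator = get_accelerator()
            accelerator.empty_cache()
            accelerator.reset_peak_memory_stats()
            accelerator.reset_max_memory_allocated()
            accelerator.reset_max_memory_cached()
            accelerator.synchronize()
            gc.collect()
            return f(*args, **kwargs)

        return _clear_cache

    return _wrap_func
```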
......@@ -4,20 +4,16 @@ from .common import (
disposable,
ensure_path_exists,
free_storage,
+get_current_device,
is_ddp_ignored,
set_seed,
)
-from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
from .multi_tensor_apply import multi_tensor_applier
from .tensor_detector import TensorDetector
from .timer import MultiTimer, Timer
__all__ = [
"conditional_context",
"get_current_device",
"synchronize",
"empty_cache",
"set_to_cuda",
"Timer",
"MultiTimer",
"multi_tensor_applier",
......@@ -27,7 +23,6 @@ __all__ = [
"_cast_float",
"free_storage",
"set_seed",
"get_current_device",
"is_ddp_ignored",
"set_device",
"IS_NPU_AVAILABLE",
]
......@@ -10,6 +10,15 @@ from typing import Callable
import numpy as np
import torch
+from colossalai.accelerator import get_accelerator
+
+
+def get_current_device():
+    """
+    A wrapper function for accelerator's API for backward compatibility.
+    """
+    return get_accelerator().get_current_device()
def ensure_path_exists(filename: str):
# ensure the path exists
......
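The full file that follows appears to be the old `colossalai/utils/device.py`, removed by this PR in favor of the accelerator API; the `get_current_device` wrapper above is the backward-compatibility shim that replaces it in `colossalai.utils.common`. If one later wanted to steer callers to the new API, the shim could also emit a warning, e.g. (purely illustrative, not part of this PR):

```python
import warnings

from colossalai.accelerator import get_accelerator


def get_current_device():
    """Backward-compatible wrapper around the accelerator API."""
    warnings.warn(
        "get_current_device() is a compatibility shim; prefer "
        "get_accelerator().get_current_device().",
        DeprecationWarning,
        stacklevel=2,
    )
    return get_accelerator().get_current_device()
```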
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple, Callable
import torch
import torch.distributed as dist
IS_NPU_AVAILABLE: bool = False
try:
import torch_npu # noqa
IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
pass
def set_to_cuda(models):
"""Send model to gpu.
:param models: nn.module or a list of module
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(get_current_device())
else:
return models.to(get_current_device())
def get_current_device() -> torch.device:
"""
Returns the currently selected device.
If CUDA is available, return the current CUDA device; if an NPU is available, return the current NPU device; otherwise return the CPU.
"""
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif IS_NPU_AVAILABLE:
return torch.device(f"npu:{torch.npu.current_device()}")
else:
return torch.device("cpu")
def _dispatch_device_func(fn_name: str, *args, **kwargs):
if torch.cuda.is_available():
return getattr(torch.cuda, fn_name)(*args, **kwargs)
elif IS_NPU_AVAILABLE:
return getattr(torch.npu, fn_name)(*args, **kwargs)
else:
raise RuntimeError("No device available")
# device semantics
def can_device_access_peer(device, peer_device) -> bool:
return _dispatch_device_func("can_device_access_peer", device, peer_device)
def current_device() -> int:
return _dispatch_device_func("current_device")
def current_stream(device=None):
return _dispatch_device_func("current_stream", device)
def default_stream(device=None):
return _dispatch_device_func("default_stream", device)
def device_count() -> int:
return _dispatch_device_func("device_count")
def get_device_capability(device=None) -> Tuple[int, int]:
return _dispatch_device_func("get_device_capability", device)
def get_device_name(device=None) -> str:
return _dispatch_device_func("get_device_name", device)
def get_device_properties(device):
return _dispatch_device_func("get_device_properties", device)
def set_device(index: Optional[int] = None) -> None:
if index is None:
index = dist.get_rank() % device_count()
_dispatch_device_func("set_device", index)
def set_stream(stream_):
return _dispatch_device_func("set_stream", stream_)
def stream(stream_):
return _dispatch_device_func("stream", stream_)
def synchronize():
return _dispatch_device_func("synchronize")
def utilization(device=None) -> int:
return _dispatch_device_func("utilization", device)
# random number generator
def get_rng_state(device="cuda") -> torch.Tensor:
return _dispatch_device_func("get_rng_state", device)
def get_rng_state_all() -> List[torch.Tensor]:
return _dispatch_device_func("get_rng_state_all")
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
return _dispatch_device_func("set_rng_state", new_state, device)
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
return _dispatch_device_func("set_rng_state_all", new_states)
def manual_seed(seed: int) -> None:
return _dispatch_device_func("manual_seed", seed)
def manual_seed_all(seed: int) -> None:
return _dispatch_device_func("manual_seed_all", seed)
def seed() -> None:
return _dispatch_device_func("seed")
def seed_all() -> None:
return _dispatch_device_func("seed_all")
def initial_seed() -> int:
return _dispatch_device_func("initial_seed")
# streams and events
def Stream(device=None, priority=0, **kwargs):
return _dispatch_device_func("Stream", device, priority, **kwargs)
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
# memory management
def empty_cache() -> None:
return _dispatch_device_func("empty_cache")
def memory_stats(device=None) -> Dict[str, Any]:
return _dispatch_device_func("memory_stats", device)
def memory_summary(device=None, abbreviated=False) -> str:
return _dispatch_device_func("memory_summary", device, abbreviated)
def memory_snapshot():
return _dispatch_device_func("memory_snapshot")
def memory_allocated(device=None) -> int:
return _dispatch_device_func("memory_allocated", device)
def max_memory_allocated(device=None) -> int:
return _dispatch_device_func("max_memory_allocated", device)
def reset_max_memory_allocated(device=None) -> None:
return _dispatch_device_func("reset_max_memory_allocated", device)
def reset_max_memory_cached(device=None) -> None:
return _dispatch_device_func("reset_max_memory_cached", device)
def memory_reserved(device=None) -> int:
return _dispatch_device_func("memory_reserved", device)
def max_memory_reserved(device=None) -> int:
return _dispatch_device_func("max_memory_reserved", device)
def set_per_process_memory_fraction(fraction: float, device=None) -> None:
return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
def reset_peak_memory_stats(device=None) -> None:
return _dispatch_device_func("reset_peak_memory_stats", device)
# amp
def autocast() -> Callable:
if torch.cuda.is_available():
return torch.cuda.amp.autocast()
elif IS_NPU_AVAILABLE:
return torch.npu.amp.autocast()
else:
raise RuntimeError("No device available")