"tests/test_legacy/test_engine/test_gradient_accumluation.py" did not exist on "5a1a095b925a78f00185954bb3783c73086dbade"
Commit 08f2920e authored by zhuwenwen's avatar zhuwenwen
Browse files

init colossalai, support dtk2304

parent da3f0934
from ._param_hookmgr import BaseParamHookMgr
__all__ = ["BaseParamHookMgr"]
from typing import Callable, List
import torch
import functools
class BaseParamHookMgr(object):
def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
r"""
register backward hook on every parameters of module
"""
self._param_list = param_list
self._hook_list = []
def register_backward_hooks(self, hook_call: Callable) -> None:
r"""
The hook_call will be called every time a gradient with respect to the a param in self.param_list
is computed.
The hook should have the following signature:
```
hook(param, grad) -> Tensor or None
```
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
for p in self._param_list:
if p.requires_grad and not hasattr(p, '_base_param_hook'):
handle = p.register_hook(functools.partial(hook_call, p))
p._base_param_hook = handle
def remove_hooks(self) -> None:
"""
Remove hooks from model parameters.
"""
for p in self._param_list:
if p.requires_grad and hasattr(p, '_base_param_hook'):
p._base_param_hook.remove()
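# A minimal usage sketch for BaseParamHookMgr (assumptions: a plain torch.nn.Linear
# model and a hypothetical _halve_grad hook; neither comes from the code above).
# The hook follows the documented signature hook(param, grad) -> Tensor or None.
import torch
import torch.nn as nn

def _halve_grad(param: torch.nn.Parameter, grad: torch.Tensor) -> torch.Tensor:
    # returning a tensor replaces the gradient seen by the optimizer
    return grad * 0.5

model = nn.Linear(4, 2)
hook_mgr = BaseParamHookMgr(list(model.parameters()))
hook_mgr.register_backward_hooks(_halve_grad)
model(torch.randn(3, 4)).sum().backward()   # every param.grad is now halved
hook_mgr.remove_hooks()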
import functools
from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional, Tuple, Type
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
class PlacementPolicy(ABC):
need_mem_stats: bool = False
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
@abstractmethod
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
raise NotImplementedError
@staticmethod
def get_default_device() -> torch.device:
return torch.device('cpu')
class CPUPlacementPolicy(PlacementPolicy):
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
volume = 0
start = time()
for chunk in can_evict_chunks:
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
volume += chunk.chunk_mem
return volume, time() - start
class CUDAPlacementPolicy(PlacementPolicy):
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
assert torch.cuda.is_available(), 'Cannot use CUDAPlacementPolicy when CUDA is not available'
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
return 0, 0
@staticmethod
def get_default_device() -> torch.device:
return get_current_device()
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
# model data may use up to (1 - _warmup_non_model_data_ratio) of CUDA memory in the warmup phase
# you can tune these ratios via AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
_warmup_non_model_data_ratio: float = 0.8
_steady_cuda_cap_ratio: float = 0.9
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs) -> Tuple[int, float]:
"""
Evict tensors from CUDA device.
Args:
can_evict_chunks (List[StatefulTensor]): the list of tensors that can be evicted.
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
compute_idx (int, optional): the idx of computing device. Defaults to 0.
Raises:
RuntimeError:
Returns:
int: the volume of memory that is evicted
"""
start = time()
cuda_capacity = colo_device_memory_capacity(get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption between this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
freed_cuda_model_data = 0
if avail_cuda_model_data < cuda_demand:
# Move (cuda_demand - avail_cuda_model_data) bytes of chunks out of CUDA
to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
to_free_chunks = can_evict_chunks
if not warmup:
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
# print(self._sort_can_evict_chunks.cache_info())
for chunk in to_free_chunks:
if freed_cuda_model_data >= to_free_cuda_model_data:
break
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
freed_cuda_model_data += chunk.chunk_mem
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
return freed_cuda_model_data, time() - start
@staticmethod
@functools.lru_cache(maxsize=None)
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
for i in range(len(compute_list) - 1, compute_idx, -1):
for chunk in compute_list[i]:
if chunk in next_compute_idx:
next_compute_idx[chunk] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
@staticmethod
def set_warmup_non_model_data_ratio(ratio: float) -> None:
ratio = float(ratio)
assert 0.0 < ratio < 1.0
AutoPlacementPolicy._warmup_non_model_data_ratio = ratio
@staticmethod
def set_steady_cuda_cap_ratio(ratio: float) -> None:
ratio = float(ratio)
assert 0.0 < ratio < 1.0
AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
class ConstPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = False
_accessed_memory_boundary = 512 * 1024**2
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs) -> Tuple[int, float]:
"""
See the docstrings in the class `AutoPlacementPolicy`.
"""
start = time()
used_accessed_memory = self.chunk_manager.accessed_mem
avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
freed_accessed_memory = 0
if avail_accessed_memory < cuda_demand:
to_free_memory = cuda_demand - avail_accessed_memory
to_free_chunks = can_evict_chunks
if not warmup:
# sort all chunks
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
for chunk in to_free_chunks:
if freed_accessed_memory >= to_free_memory:
break
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
freed_accessed_memory += chunk.chunk_mem
if freed_accessed_memory < to_free_memory:
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_memory}, freed {freed_accessed_memory}")
return freed_accessed_memory, time() - start
@staticmethod
@functools.lru_cache(maxsize=None)
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
for i in range(len(compute_list) - 1, compute_idx, -1):
for chunk in compute_list[i]:
if chunk in next_compute_idx:
next_compute_idx[chunk] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
@staticmethod
def set_const_memory_boundary(cuda_memory_mb: int) -> None:
boundary = int(cuda_memory_mb * 1024**2)
assert boundary > 0
ConstPlacementPolicy._accessed_memory_boundary = boundary
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {
'cpu': CPUPlacementPolicy,
'cuda': CUDAPlacementPolicy,
'auto': AutoPlacementPolicy,
'const': ConstPlacementPolicy
}
@staticmethod
def create(policy_name: str) -> Type[PlacementPolicy]:
if policy_name not in PlacementPolicyFactory.policies:
raise TypeError(f"Unknown tensor placement policy {policy_name}")
return PlacementPolicyFactory.policies[policy_name]
@staticmethod
def get_policy_names():
return tuple(PlacementPolicyFactory.policies.keys())
@staticmethod
def get_default_device(policy_name: str) -> torch.device:
policy_cls = PlacementPolicyFactory.create(policy_name)
return policy_cls.get_default_device()
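# A short sketch of how the factory above is typically used (assumption: the
# ChunkManager needed to instantiate a policy is constructed elsewhere and is
# elided here). All calls below exist in this file.
policy_cls = PlacementPolicyFactory.create('auto')
assert policy_cls is AutoPlacementPolicy
# tune the 'auto' policy before building a policy instance
AutoPlacementPolicy.set_warmup_non_model_data_ratio(0.7)
AutoPlacementPolicy.set_steady_cuda_cap_ratio(0.85)
print(PlacementPolicyFactory.get_policy_names())          # ('cpu', 'cuda', 'auto', 'const')
print(PlacementPolicyFactory.get_default_device('cpu'))   # device(type='cpu')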
from enum import Enum
from typing import Optional
import torch
from typing import Union
from colossalai.gemini.gemini_context import GeminiMemoryManager
def sizeof_tensor(tensor: torch.Tensor):
return tensor.numel() * tensor.element_size()
class TensorState(Enum):
FREE = 0
HOLD = 1
HOLD_AFTER_FWD = 2
HOLD_AFTER_BWD = 3
COMPUTE = 4
class StatefulTensor(object):
"""A Structure stores a Torch Tensor and labeled states.
Inspired from the paper:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
"""
# Global Stateful Tensor Manager
GST_MGR = GeminiMemoryManager(TensorState)
def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
self._state = state
self._payload = None
self._payload_size = 0 # byte size of current payload
StatefulTensor.GST_MGR.register_new_instance()
if self._state == TensorState.FREE:
# when the state is free, payload should be None
assert maybe_tensor is None, f"payload has to be None if state is {self._state}"
else:
# otherwise, payload should not be None
assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
self._payload = maybe_tensor
self._payload_size = sizeof_tensor(maybe_tensor)
self.__trans_state_update(TensorState.FREE, state)
def data_ptr(self):
if self._payload is None:
return 0 # if a tensor has no storage, 0 should be returned
return self._payload.data_ptr()
def set_null(self) -> None:
# note that a FREE stateful tensor does not need to be set to null again
if self.state != TensorState.FREE:
self.__trans_state_update(self.state, TensorState.FREE)
self.__release()
def is_null(self) -> bool:
if self.state == TensorState.FREE:
# check sanity here
assert self.payload is None
return True
return False
def trans_state(self, state: TensorState) -> None:
if self.state == TensorState.FREE:
# free stateful tensor can't change state
assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
return
self.__trans_state_update(self.state, state)
if state == TensorState.FREE:
self.__release()
else:
self._state = state
def move_to(self, device: Union[torch.device, int]):
assert self.state is not TensorState.FREE, "Can't move free stateful tensor"
if not isinstance(device, torch.device):
to_device = torch.device('cuda', device)
else:
to_device = device
from_device_type = self.device.type
if from_device_type == to_device.type:
# from device == to device
return
# update manager's information
self.__trans_device_update(from_device_type, to_device.type)
self.payload.data = self.payload.data.to(to_device)
def payload_copy(self, tensor) -> None:
self._payload.view(-1).copy_(tensor.view(-1))
def payload_reset(self, tensor) -> None:
assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"
if self.payload is not None:
# release old payload
self.__trans_state_update(self.state, TensorState.FREE)
else:
# otherwise, set the state to HOLD for new payload
self._state = TensorState.HOLD
del self._payload
self._payload = tensor
self._payload_size = sizeof_tensor(tensor)
# record new payload
self.__trans_state_update(TensorState.FREE, self.state)
def payload_relay(self, rhs):
# relay the payload of rhs to current stateful tensor
# can't support null relay right now
assert not rhs.is_null()
# now this function only support stateful tensor that has zero-length payload
# because it doesn't require memory manager updating
# you can extend this function by yourself
assert self.payload_size == 0
self._payload = rhs.payload
self._payload_size = rhs.payload_size
self._state = TensorState.HOLD
self.__trans_state_update(rhs.state, TensorState.HOLD)
rhs.__release()
@property
def payload(self) -> Optional[torch.Tensor]:
return self._payload
@property
def payload_size(self) -> int:
return self._payload_size
@property
def state(self) -> TensorState:
return self._state
@property
def device(self) -> torch.device:
return self._payload.device
@property
def dtype(self) -> torch.dtype:
return self._payload.dtype
@property
def shape(self):
return self._payload.shape
def to(self, device: torch.device):
raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")
def to_(self, device: torch.device):
raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")
def __release(self):
# release current payload
# shouldn't be visible to users
self._state = TensorState.FREE
self._payload = None
self._payload_size = 0
def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
"""Update global manager when changing the state of a tensor
"""
manager = StatefulTensor.GST_MGR
size = self.payload_size
device_type = self.device.type
if from_state != TensorState.FREE:
manager.state_mem[device_type][from_state] -= size
else:
# when from_state is FREE, the tensor is new to manager
# we should add its memory
manager.total_mem[device_type] += size
if to_state != TensorState.FREE:
manager.state_mem[device_type][to_state] += size
else:
# when to_state is FREE, the tensor will be deleted soon
# we should sub its memory
manager.total_mem[device_type] -= size
def __trans_device_update(self, from_type: str, to_type: str):
"""Update global manager when changing the device of a tensor
"""
manager = StatefulTensor.GST_MGR
size = self.payload_size
state = self.state
# update aggregated information
manager.total_mem[from_type] -= size
manager.total_mem[to_type] += size
# update the information of each state
manager.state_mem[from_type][state] -= size
manager.state_mem[to_type][state] += size
def __del__(self):
self.set_null()
StatefulTensor.GST_MGR.delete_instance()
del self
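# A brief sketch of the StatefulTensor life cycle, assuming a CPU tensor; the
# byte accounting goes through the global GeminiMemoryManager (GST_MGR).
t = StatefulTensor(torch.zeros(16), TensorState.HOLD)   # 16 * 4 = 64 bytes registered
t.trans_state(TensorState.COMPUTE)                       # now counted under the COMPUTE state
print(StatefulTensor.GST_MGR.total_mem['cpu'])           # includes the 64-byte payload
t.trans_state(TensorState.HOLD_AFTER_FWD)
t.set_null()                                             # payload released, state becomes FREE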
import functools
import torch
import types
from colossalai.utils.cuda import get_current_device
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List
from colossalai.logging import get_dist_logger
from time import time
class StatefulTensorMgr(object):
"""
Stateful Tensor Manager, inspired by PatrickStar
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
"""
def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
self._stateful_tensor_list: List[StatefulTensor] = []
self._compute_list: List[StatefulTensor] = []
self._compute_idx: int = -1
self._cpu_gpu_move_volume = 0
self._layout_time = 0
self._evict_time = 0
self._warmup = True
def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
self._stateful_tensor_list = tensor_list
for t in self._stateful_tensor_list:
assert isinstance(t, StatefulTensor)
t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)
def start_iter(self):
pass
def finish_iter(self):
"""This function must be called when each iteration finishes
"""
self._warmup = False
self._compute_idx = -1
self._cpu_gpu_move_volume = 0
self._layout_time = 0
self._evict_time = 0
def adjust_layout(self) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
start = time()
move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
self._layout_time += time() - start
vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
compute_idx=self._compute_idx)
self._cpu_gpu_move_volume += vol
self._evict_time += evict_time
# move COMPUTE tensors to CUDA
self._cpu_gpu_move_volume += cuda_demand
for t in move_to_cuda_tensor_list:
colo_model_data_tensor_move_inline(t, get_current_device())
@property
def cpu_gpu_move_volume(self):
return self._cpu_gpu_move_volume
def _trans_state(self, trans_state_func, stateful_tensor, state):
trans_state_func(state)
if state == TensorState.COMPUTE:
self._compute_idx += 1
if self._warmup:
self._compute_list.append(stateful_tensor)
@functools.lru_cache(maxsize=None)
def _get_layout_info(self, compute_idx: int, warmup: bool):
move_to_cuda_tensor_list = []
hold_cuda_tensor_list = []
for tensor in self._stateful_tensor_list:
if tensor.state == TensorState.FREE:
continue
if tensor.device.type == 'cuda':
if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
hold_cuda_tensor_list.append(tensor)
elif tensor.device.type == 'cpu':
if tensor.state == TensorState.COMPUTE:
move_to_cuda_tensor_list.append(tensor)
else:
raise RuntimeError
return move_to_cuda_tensor_list, hold_cuda_tensor_list
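# A rough sketch of driving the manager for one iteration, assuming the CPU
# placement policy defined further below in this commit; in the sharded model,
# hooks call trans_state()/adjust_layout() around every operator.
policy = CPUTensorPlacementPolicy()
mgr = StatefulTensorMgr(policy)
params = [StatefulTensor(torch.zeros(8)) for _ in range(4)]
mgr.register_stateful_tensor_list(params)
mgr.start_iter()
params[0].trans_state(TensorState.COMPUTE)   # recorded into the compute list during warmup
mgr.adjust_layout()                          # moves COMPUTE tensors to the current device
mgr.finish_iter()
print(mgr.cpu_gpu_move_volume)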
from abc import ABC, abstractmethod
from time import time
from typing import List, Optional, Tuple
import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.memory_tracer import MemStatsCollector
from typing import Type
import functools
class TensorPlacementPolicy(ABC):
def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
self.device: Optional[torch.device] = device
self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector
@abstractmethod
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> Tuple[int, float]:
raise NotImplementedError
class CPUTensorPlacementPolicy(TensorPlacementPolicy):
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector)
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> Tuple[int, float]:
volume = 0
for t in hold_cuda_tensor_list:
colo_model_data_tensor_move_inline(t, self.device)
volume += t.payload.numel() * t.payload.element_size()
return volume, 0
class CUDATensorPlacementPolicy(TensorPlacementPolicy):
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> Tuple[int, float]:
return 0, 0
class AutoTensorPlacementPolicy(TensorPlacementPolicy):
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
super().__init__(None, mem_stats_collector=mem_stats_collector)
# model data may use up to (1 - self._warmup_non_model_data_ratio) of CUDA memory in the warmup phase
# TODO(ver217): make these args configurable
self._warmup_non_model_data_ratio: float = 0.8
self._steady_cuda_cap_ratio: float = 0.9
def evict_tensors(self,
hold_cuda_tensor_list: List[StatefulTensor],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: List[StatefulTensor] = [],
compute_idx: int = 0,
**kwargs) -> Tuple[int, float]:
"""
Evict tensors from CUDA device.
Args:
hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
compute_idx (int, optional): the idx of computing device. Defaults to 0.
Raises:
RuntimeError:
Returns:
int: the volume of memory that is evicted
"""
start = time()
cuda_capacity = colo_device_memory_capacity(get_current_device())
used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda']
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption between this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
cuda_capacity *= self._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
freed_cuda_model_data = 0
end = time()
if avail_cuda_model_data < cuda_demand:
# Move (cuda_demand - avail_cuda_model_data) bytes of tensors to CPU
to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
to_free_tensor_list = hold_cuda_tensor_list
if not warmup:
to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx,
tuple(compute_list))
# print(self._sort_hold_cuda_tensors.cache_info())
end = time()
for t in to_free_tensor_list:
if freed_cuda_model_data >= to_free_cuda_model_data:
break
freed_cuda_model_data += t.payload_size
colo_model_data_tensor_move_inline(t, torch.device('cpu'))
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
)
return freed_cuda_model_data, end - start
@staticmethod
@functools.lru_cache(maxsize=None)
def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_list: tuple) -> list:
next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensors}
for i in range(len(compute_list) - 1, compute_idx, -1):
if compute_list[i] in next_compute_idx:
next_compute_idx[compute_list[i]] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
class TensorPlacementPolicyFactory:
@staticmethod
def create(policy_name: str) -> Type[TensorPlacementPolicy]:
if policy_name == 'cpu':
return CPUTensorPlacementPolicy
elif policy_name == 'cuda':
return CUDATensorPlacementPolicy
elif policy_name == 'auto':
return AutoTensorPlacementPolicy
else:
raise TypeError(f"Unknown tensor placement policy {policy_name}")
import torch
from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
if isinstance(tensor, StatefulTensor):
t = tensor.payload
elif isinstance(tensor, torch.Tensor):
t = tensor
else:
return 0, 0
cuda_use, cpu_use = 0, 0
mem_use = t.storage().size() * t.element_size()
if t.device.type == 'cuda':
cuda_use += mem_use
elif t.device.type == 'cpu':
cpu_use += mem_use
return cuda_use, cpu_use
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
torch.Tensor]) -> None:
"""
A colossal API for model data tensor move.
The src and target tensors could be resident on both CPU and GPU.
NOTE() The source tensor payload will be removed after this function.
The function will record the communication volume between CPU and GPU.
Args:
src_t (Union[StatefulTensor, torch.Tensor]): source tensor
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
"""
if isinstance(src_t, StatefulTensor):
src_t_payload = src_t.payload
else:
src_t_payload = src_t.data
src_dev = src_t_payload.device
if isinstance(tgt_t, StatefulTensor):
tgt_t_payload = tgt_t.payload
else:
tgt_t_payload = tgt_t.data
tgt_t_payload.copy_(src_t_payload)
# remove payload of src_t
if isinstance(src_t, StatefulTensor):
src_t.set_null()
else:
src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
int]) -> None:
"""
move a tensor to the target_device
Args:
t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
target_device: the target device; if it is an int, it is interpreted as the index of a CUDA device.
"""
if not isinstance(target_device, torch.device):
target_device = torch.device(f'cuda:{target_device}')
if isinstance(t, torch.Tensor):
t.data = t.data.to(target_device)
elif isinstance(t, StatefulTensor):
t.move_to(target_device)
else:
raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
"""colo_model_data_move_to_cpu
move a model data tensor from gpu to cpu
Args:
t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved to CPU
"""
# TODO() optimize the tensor moving with non-blocking
if isinstance(t, torch.Tensor):
t.data = t.data.cpu()
elif isinstance(t, StatefulTensor):
t.move_to(torch.device('cpu'))
else:
raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
"""
Clone a model data tensor
Args:
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
target_device (torch.device): the target device
Returns:
torch.Tensor: a cloned torch tensor
"""
# TODO() rename this function
colo_model_data_tensor_move_inline(t, target_device)
t_payload = t.payload if isinstance(t, StatefulTensor) else t
return t_payload
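# A small sketch of the move helpers above, assuming CPU-only execution;
# colo_model_data_tensor_move copies src into tgt and then nullifies src.
src = StatefulTensor(torch.ones(4))
tgt = StatefulTensor(torch.zeros(4))
colo_model_data_tensor_move(src, tgt)
assert src.is_null() and tgt.payload.sum().item() == 4
print(colo_tensor_mem_usage(tgt))            # (cuda_bytes, cpu_bytes) == (0, 16)
colo_model_data_move_to_cpu(tgt)             # no-op here, tgt is already on CPU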
from typing import Optional
class TensorParallelEnv(object):
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = object.__new__(cls)
return cls._instance
def __init__(self, *args, **kwargs):
self.load(*args, **kwargs)
def load(self,
mode: Optional[str] = None,
vocab_parallel: bool = False,
parallel_input_1d: bool = False,
summa_dim: int = None,
tesseract_dim: int = None,
tesseract_dep: int = None,
depth_3d: int = None,
input_group_3d=None,
weight_group_3d=None,
output_group_3d=None,
input_x_weight_group_3d=None,
output_x_weight_group_3d=None):
self.mode = mode
self.vocab_parallel = vocab_parallel
self.parallel_input_1d = parallel_input_1d
self.summa_dim = summa_dim
self.tesseract_dim = tesseract_dim
self.tesseract_dep = tesseract_dep
self.depth_3d = depth_3d
self.input_group_3d = input_group_3d
self.weight_group_3d = weight_group_3d
self.output_group_3d = output_group_3d
self.input_x_weight_group_3d = input_x_weight_group_3d
self.output_x_weight_group_3d = output_x_weight_group_3d
def save(self):
return dict(mode=self.mode,
vocab_parallel=self.vocab_parallel,
parallel_input_1d=self.parallel_input_1d,
summa_dim=self.summa_dim,
tesseract_dim=self.tesseract_dim,
tesseract_dep=self.tesseract_dep,
depth_3d=self.depth_3d,
input_group_3d=self.input_group_3d,
weight_group_3d=self.weight_group_3d,
output_group_3d=self.output_group_3d,
input_x_weight_group_3d=self.input_x_weight_group_3d,
output_x_weight_group_3d=self.output_x_weight_group_3d)
tensor_parallel_env = TensorParallelEnv()
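# A tiny sketch of the singleton: load() overwrites the shared state and save()
# returns it as a plain dict; '1d' is just an illustrative mode string here.
assert TensorParallelEnv() is tensor_parallel_env     # always the same singleton
tensor_parallel_env.load(mode='1d', parallel_input_1d=True)
state = tensor_parallel_env.save()
assert state['mode'] == '1d' and state['parallel_input_1d'] is True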
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import argparse
import os
import pprint
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from colossalai.core import global_context as gpc
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.logging import get_dist_logger
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
from colossalai.engine import Engine
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
from colossalai.utils.moe import sync_moe_model_param
from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.engine.gradient_accumulation import accumulate_gradient
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
def get_default_parser():
"""Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
Returns:
Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
"""
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host', type=str, help='the master address for distributed training')
parser.add_argument('--port', type=int, help='the master port for distributed training')
parser.add_argument('--world_size', type=int, help='world size for distributed training')
parser.add_argument('--rank', type=int, help='rank for the default process group')
parser.add_argument('--local_rank', type=int, help='local rank on the node')
parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
return parser
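# A minimal sketch of extending the default parser before parsing; --epochs and
# ./config.py are placeholders, not options defined by this file.
parser = get_default_parser()
parser.add_argument('--epochs', type=int, default=10)
args = parser.parse_args(['--config', './config.py', '--epochs', '3'])
print(args.config, args.epochs)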
def launch(config: Union[str, Path, Config, Dict],
rank: int,
world_size: int,
host: str,
port: int,
backend: str = 'nccl',
local_rank: int = None,
seed: int = 1024,
verbose: bool = True):
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
Args:
config (Union[str, dict, Config]): Config file or config file path are both acceptable
rank (int): Rank for the default process group
world_size (int): World size of the default process group
host (str): The master address for distributed training
port (int): The master port for distributed training
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
local_rank (int, optional):
Rank for the process on the node and is used to set the default CUDA device,
defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
seed (int, optional): Specified random seed for every process. Defaults to 1024.
verbose (bool, optional): Whether to print logs. Defaults to True.
Raises:
AssertionError: Raised when the config is not of type Config, str, Path or dict
"""
gpc.verbose = verbose
# set config
assert isinstance(config, (Config, str, Path, dict)), \
f'expected argument config to be Config, str or Path, but got {type(config)}'
if not isinstance(config, Config) and isinstance(config, dict):
config = Config(config)
if isinstance(config, (str, Path)):
config = Config.from_file(config)
gpc.load_config(config)
# init default process group
gpc.init_global_dist(rank, world_size, backend, host, port)
# init process groups for different parallel modes from config
gpc.init_parallel_groups()
# set cuda device
if torch.cuda.is_available():
# if local rank is not given, calculate automatically
gpc.set_device(local_rank)
# set the number of processes running on the same node
gpc.detect_num_processes_on_current_node()
gpc.set_seed(seed)
if verbose:
logger = get_dist_logger()
logger.info(
f'Distributed environment is initialized, '
f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}',
ranks=[0])
def launch_from_slurm(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
"""A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
set by SLURM
Args:
config (Union[str, dict, Config]): Config file or config file path are both acceptable
host (str): The master address for distributed training
port (int): The master port for distributed training
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
seed (int, optional): Specified random seed for every process. Defaults to 1024.
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NPROCS'])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
)
launch(config=config,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
def launch_from_openmpi(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
"""A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
set by OpenMPI
Args:
config (Union[str, dict, Config]): Config file or config file path are both acceptable
host (str): The master address for distributed training
port (int): The master port for distributed training
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
seed (int, optional): Specified random seed for every process. Defaults to 1024.
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
)
launch(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
def launch_from_torch(config: Union[str, Path, Config, Dict],
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
Args:
config (Union[str, dict, Config]): Config file or config file path are both acceptable
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
seed (int, optional): Specified random seed for every process. Defaults to 1024.
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
launch(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
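# A hedged sketch of the torchrun entry point: RANK, LOCAL_RANK, WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT are expected to be set by torchrun itself, and
# './config.py' is a placeholder path to a user config.
if __name__ == '__main__':
    launch_from_torch(config='./config.py', backend='nccl', seed=1024)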
def initialize(model: nn.Module,
optimizer: Optimizer,
criterion: Optional[_Loss] = None,
train_dataloader: Optional[Iterable] = None,
test_dataloader: Optional[Iterable] = None,
lr_scheduler: Optional[_LRScheduler] = None,
ophooks: Optional[List[BaseOpHook]] = None,
verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
Args:
model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
Your optimizer instance.
criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
verbose (bool, optional): Whether to print logs.
Returns:
Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
where only ``engine`` is guaranteed not to be ``None``.
"""
# get logger
logger = get_dist_logger()
gpc.verbose = verbose
# get config from gpc
config = gpc.config
# print config
if verbose:
logger.info(
f"\n========== Your Config ========\n"
f"{pprint.pformat(gpc.config)}\n"
f"================================\n",
ranks=[0])
# cudnn
cudnn_benchmark = config.get('cudnn_benchmark', False)
cudnn_deterministic = config.get('cudnn_deterministic', False)
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.backends.cudnn.deterministic = cudnn_deterministic
if verbose:
logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
# zero
use_zero = hasattr(gpc.config, 'zero')
if use_zero:
zero_cfg = gpc.config.get('zero', None)
if zero_cfg is not None:
cfg_ = zero_cfg.copy()
else:
cfg_ = {}
optimizer_config = zero_cfg.get('optimizer_config', None)
model_config = zero_cfg.get('model_config', None)
model, optimizer = convert_to_zero_v2(model,
optimizer,
model_config=model_config,
optimizer_config=optimizer_config)
logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
else:
if isinstance(model, nn.Module):
# first sync model across dp ranks
model.to(get_current_device())
elif isinstance(model, Callable):
model = model().to(get_current_device())
# the optimizer may be an optimizer class rather than an instance
logger.warning("Initializing a non-ZeRO model with an optimizer class")
if isinstance(optimizer, Callable):
optimizer = optimizer(model.parameters())
if not use_zero:
if is_using_sequence():
sync_model_param(model, ParallelMode.SEQUENCE_DP)
elif MOE_CONTEXT.is_initialized:
sync_moe_model_param(model)
elif is_using_ddp():
sync_model_param(model, ParallelMode.DATA)
else:
logger.warning(
"The parameters of models is not automatically synchronized.\n"
"Please make sure that all parameters are the same in data parallel group.",
ranks=[0])
# check amp and zero
fp16_cfg = gpc.config.get('fp16', None)
if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
raise ConfigException(
"It is not allowed to set fp16 and zero configuration in your config file at the same time")
# clip grad norm
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
# initialize amp
amp_mode = None
if fp16_cfg is not None and fp16_cfg.mode is not None:
cfg_ = fp16_cfg.copy()
amp_mode = cfg_.pop('mode')
if is_using_pp():
assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline parallelism currently only supports NaiveAMP'
if amp_mode == AMP_TYPE.NAIVE:
cfg_['clip_grad_norm'] = clip_grad_norm
model, optimizer, criterion = convert_to_amp(model=model,
optimizer=optimizer,
criterion=criterion,
mode=amp_mode,
amp_config=cfg_)
# get torch ddp config
torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
# gradient handler
gradient_handler_cfg = gpc.config.get('gradient_handler', None)
if gradient_handler_cfg is None:
# if gradient handler is not specified in the configuration file,
# check in the following order
# 1. if optimizer is ZERO, then use zero grad handler
# 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
# 3. if using pipeline and dp size larger than 1, use data parallel grad handler
if isinstance(optimizer, ShardedOptimizerV2):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
if verbose:
logger.info(
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif is_using_ddp() and MOE_CONTEXT.is_initialized:
gradient_handler_cfg = [dict(type='MoeGradientHandler')]
if verbose:
logger.info(
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
elif is_using_sequence():
model = DDP(model,
process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
device_ids=[torch.cuda.current_device()],
**torch_ddp_cfg)
if verbose:
logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
ranks=[0])
elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
model = DDP(model,
process_group=gpc.get_group(ParallelMode.DATA),
device_ids=[torch.cuda.current_device()],
**torch_ddp_cfg)
if verbose:
logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
elif is_using_ddp():
gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
if verbose:
logger.info(
"Data parallel training is detected when using pipeline parallel, "
"DataParallelGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
# add pipeline parallel gradient handler, if pipeline shared module is detected
for param in model.parameters():
if getattr(param, 'pipeline_shared_module_pg', None) is not None:
if gradient_handler_cfg is None:
gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
else:
gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
if verbose:
logger.info(
"pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
"added even though not specified in the configuration",
ranks=[0])
break
else:
if not isinstance(gradient_handler_cfg, list):
raise ConfigException(
f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
)
# turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
# to avoid duplicated buffer synchronization
if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
model.module.sync_buffer = False
# initialize schedule for engine
if is_using_pp():
tensor_shape = get_tensor_shape()
use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
if gpc.is_initialized(ParallelMode.PARALLEL_1D):
scatter_gather = True
else:
scatter_gather = False
if use_interleaved:
if isinstance(model, nn.Sequential):
model = nn.ModuleList([model])
schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
gpc.config.model.num_chunks,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather)
else:
schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather)
else:
schedule = NonPipelineSchedule()
if gradient_handler_cfg is None:
gradient_handlers = None
if verbose and not isinstance(model, DDP):
logger.warning(
"No PyTorch DDP or gradient handler is set up, please make sure you do not need "
"to all-reduce the gradients after a training step.",
ranks=[0])
else:
gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
# check if optimizer is ColossalaiOptimizer
if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
optimizer = ColossalaiOptimizer(optim=optimizer)
# gradient accumulation
grad_accum_size = gpc.config.get('gradient_accumulation', None)
if grad_accum_size is not None:
optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
accumulate_size=grad_accum_size,
gradient_handlers=gradient_handlers,
lr_scheduler=lr_scheduler)
engine = Engine(model=model,
optimizer=optimizer,
criterion=criterion,
gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm,
ophook_list=ophooks,
schedule=schedule)
return engine, train_dataloader, test_dataloader, lr_scheduler
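# A minimal sketch of the intended call order, assuming one of the launch_*
# functions above has already populated gpc.config and the config defines no
# fp16/zero/pipeline sections; model, optimizer and criterion are plain torch objects.
import torch.nn as nn
from torch.optim import SGD

model = nn.Linear(16, 16)
optimizer = SGD(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()
engine, train_dl, test_dl, lr_sched = initialize(model, optimizer, criterion)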
from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention
__all__ = ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"]
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .multihead_attention import MultiHeadAttention
from .scaled_softmax import FusedScaleMaskSoftmax
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <torch/extension.h>
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale);
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd, float momentum, float dampening, float lr,
bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction, const float weight_decay,
const float div_scale);
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int bias_correction,
const float weight_decay, const int grad_averaging,
const int mode, at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
"Computes and apply update for LAMB optimizer");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
}
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
#include <math.h>
#include <omp.h>
#include <string.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
// C++ interface
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
size_t rounded_size = 0;
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
if (grad_half_precision) {
grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i);
} else {
grad_4.data = SIMD_LOAD(grads + i);
}
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
}
AVX_Data momentum_4;
momentum_4.data = SIMD_LOAD(_exp_avg + i);
AVX_Data variance_4;
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
AVX_Data param_4;
if (param_half_precision) {
param_4.data = SIMD_LOAD_HALF(params_cast_h + i);
} else {
param_4.data = SIMD_LOAD(_params + i);
}
if (_weight_decay > 0 && !_adamw_mode) {
grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
}
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data =
SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data =
SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
if (_weight_decay > 0 && _adamw_mode) {
param_4.data =
SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
}
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
if (param_half_precision) {
SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data);
} else {
SIMD_STORE(_params + i, param_4.data);
}
SIMD_STORE(_exp_avg + i, momentum_4.data);
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
if (loss_scale > 0) {
grad /= loss_scale;
}
float param =
param_half_precision ? (float)params_cast_h[k] : _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) {
param += w_decay * param;
}
param = grad * step_size + param;
if (param_half_precision)
params_cast_h[k] = (__half)param;
else
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
}
}
}
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
size_t rounded_size = 0;
__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
float betta2_minus1 = 1 - _betta2;
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
AVX_Data grad_4[4];
AVX_Data momentum_4[4];
AVX_Data variance_4[4];
AVX_Data param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
if (grad_half_precision) {
grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
} else {
grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
}
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}
momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
if (param_half_precision) {
param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
} else {
param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
}
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
}
momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
momentum_4[j].data =
SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
variance_4[j].data =
SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
grad_4[j].data = SIMD_SQRT(variance_4[j].data);
grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
}
param_4[j].data =
SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
if (param_half_precision) {
SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
param_4[j].data);
} else {
SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
}
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
}
}
}
#endif
if (_param_size > rounded_size)
Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, loss_scale);
}
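// Step_8 below follows the same structure as Step_4 but unrolls 8 SIMD
// vectors per iteration; any tail that is not a multiple of SIMD_WIDTH * 8 is
// handed down to Step_4, which in turn hands its own tail to the scalar
// Step_1.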
void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
float loss_scale) {
size_t rounded_size = 0;
__half *params_cast_h = NULL;
__half *grads_cast_h = NULL;
if (param_half_precision) {
params_cast_h = reinterpret_cast<__half *>(_params);
}
if (grad_half_precision) {
grads_cast_h = reinterpret_cast<__half *>(grads);
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
float betta2_minus1 = 1 - _betta2;
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
AVX_Data grad_4[8];
AVX_Data momentum_4[8];
AVX_Data variance_4[8];
AVX_Data param_4[8];
#pragma unroll 8
for (int j = 0; j < 8; j++) {
if (grad_half_precision) {
grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
} else {
grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
}
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}
momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
if (param_half_precision) {
param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
} else {
param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
}
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
}
momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
momentum_4[j].data =
SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
variance_4[j].data =
SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
grad_4[j].data = SIMD_SQRT(variance_4[j].data);
grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
}
param_4[j].data =
SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
if (param_half_precision) {
SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
param_4[j].data);
} else {
SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
}
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
}
}
}
#endif
if (_param_size > rounded_size)
Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, loss_scale);
}
void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float *params_ptr = (float *)params_c.data_ptr();
float *grads_ptr = (float *)grads_c.data_ptr();
float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
params_c.numel(), (params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf), loss_scale);
}
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
.def(py::init<float, float, float, float, float, bool>())
.def("step", &Adam_Optimizer::step);
}
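// Illustrative Python usage of the binding above (a sketch; the module name
// is set at build time via TORCH_EXTENSION_NAME and is assumed here to be
// "cpu_adam"):
//   opt = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)
//   opt.step(step, lr, beta1, beta2, eps, weight_decay, True,
//            param, grad, exp_avg, exp_avg_sq, -1.0)
// param, grad, exp_avg and exp_avg_sq are contiguous CPU tensors of the same
// shape; loss_scale = -1 disables gradient unscaling.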
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <hip/hip_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <torch/extension.h>
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm256_store_ps( \
x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm_store_ps( \
x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#endif
union AVX_Data {
#if defined(__AVX512__)
__m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
__m256 data;
#endif
// float data_f[16];
};
#endif
#define STEP(SPAN) \
void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
float *_exp_avg_sq, size_t _param_size, \
bool param_half_precision = false, \
bool grad_half_precision = false, float loss_scale = -1);
class Adam_Optimizer {
public:
Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {}
~Adam_Optimizer() {}
STEP(1)
STEP(4)
STEP(8)
inline void IncrementStep(size_t step, float beta1, float beta2) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
inline void update_state(float lr, float epsilon, float weight_decay,
bool bias_correction) {
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;
    _bias_correction1 = 1.0f;
    _bias_correction2 = 1.0f;
    if (bias_correction) {
      // Adam bias correction: 1 - beta1^t and 1 / sqrt(1 - beta2^t).
      _bias_correction1 = 1 - _betta1_t;
      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
    }
}
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
private:
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
size_t _step;
float _bias_correction1;
float _bias_correction2;
bool _adamw_mode;
};
#include "block_reduce.h"
#include "cuda_util.h"
#include "kernels.h"
#include "ls_cub.cuh"
ls::cub::CachingDeviceAllocator g_allocator(true);
template <typename T>
__global__ void ls_cross_entropy_fw_kernel(
const T *__restrict__ inputs, const int *__restrict__ targets,
float *__restrict__ outputs, float *__restrict__ nll_loss_outputs,
const int padding_idx, const float epsilon, const int vocab_size) {
/* step1: compute each thread's max_logit and sum_exp_logit, store in
* max_input, sum_exp_logit */
const int block_start = blockIdx.x * vocab_size;
const int left_idx = block_start + threadIdx.x;
const int right_idx = (blockIdx.x + 1) * vocab_size;
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
int target_tid = targets[blockIdx.x];
if (target_tid == padding_idx) {
if (threadIdx.x == 0) {
nll_loss_outputs[blockIdx.x] = 0.f;
outputs[blockIdx.x] = 0.f;
}
return;
}
for (int i = left_idx; i < right_idx; i += blockDim.x) {
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
}
blockReduce<ReduceType::kMax, 1>(max_input);
__shared__ float s_max_input;
if (threadIdx.x == 0) {
s_max_input = max_input[0];
}
__syncthreads();
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float logit = static_cast<float>(inputs[i]) - s_max_input;
sum_logits[0] += logit;
sum_logits[1] += expf(logit);
}
blockReduce<ReduceType::kSum, 2>(sum_logits);
__shared__ float s_sum_logit;
__shared__ float s_sum_exp;
if (threadIdx.x == 0) {
s_sum_logit = sum_logits[0];
s_sum_exp = sum_logits[1];
}
__syncthreads();
float eps_i = epsilon / (vocab_size - 1);
if (threadIdx.x == 0) {
// neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max)
float nll_loss = logf(s_sum_exp) -
static_cast<float>(inputs[block_start + target_tid]) +
s_max_input;
nll_loss_outputs[blockIdx.x] = nll_loss;
float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit;
outputs[blockIdx.x] =
(1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
}
}
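// The forward kernel above computes a label-smoothed cross entropy per token:
//   nll  = log(sum_j exp(x_j - x_max)) - (x_target - x_max)
//   loss = (1 - epsilon - eps_i) * nll + eps_i * sum_j nll_j
// with eps_i = epsilon / (vocab_size - 1); rows whose target equals
// padding_idx contribute zero loss.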
template <typename T>
__global__ void ls_cross_entropy_bw_kernel(
const float *__restrict__ grad_outputs, const T *__restrict__ inputs,
const int *__restrict__ targets, T *__restrict__ grad_inputs,
const int padding_idx, const float epsilon, const int vocab_size) {
/* step1: compute each thread's max_logit and sum_exp_logit, store in
* max_input, sum_exp_logit */
const int block_start = blockIdx.x * vocab_size;
const int left_idx = block_start + threadIdx.x;
const int right_idx = (blockIdx.x + 1) * vocab_size;
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
float sum_logits[1] = {0.f};
const float grad_out = static_cast<float>(grad_outputs[0]);
int target_tid = targets[blockIdx.x];
if (target_tid == padding_idx) {
for (int i = left_idx; i < right_idx; i += blockDim.x) {
grad_inputs[i] = 0.f;
}
return;
}
for (int i = left_idx; i < right_idx; i += blockDim.x) {
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
}
blockReduce<ReduceType::kMax, 1>(max_input);
__shared__ float s_max_input;
if (threadIdx.x == 0) {
s_max_input = max_input[0];
}
__syncthreads();
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float logit = static_cast<float>(inputs[i]) - s_max_input;
sum_logits[0] += expf(logit);
}
blockReduce<ReduceType::kSum, 1>(sum_logits);
__shared__ float s_sum_exp;
if (threadIdx.x == 0) {
s_sum_exp = sum_logits[0];
}
__syncthreads();
float eps_i = epsilon / (vocab_size - 1);
float nll_weight = 1.0 - epsilon - eps_i;
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float prob = expf(static_cast<float>(inputs[i]) - s_max_input) / s_sum_exp;
float grad = 0;
grad += (vocab_size * prob - 1) * eps_i;
grad += prob * nll_weight;
if ((i - block_start) == target_tid) {
grad -= nll_weight;
}
grad_inputs[i] = grad_out * grad;
}
}
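// Gradient of the label-smoothed loss w.r.t. each logit, as computed above:
//   dx_i = grad_out * [ eps_i * (vocab_size * p_i - 1)
//                       + (1 - epsilon - eps_i) * (p_i - 1{i == target}) ]
// where p_i = softmax(x)_i; rows at padding_idx receive zero gradient.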
template <typename T>
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
float *outputs_ptr, float *nll_loss_ptr,
float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size,
const int seq_len, const int vocab_size,
cudaStream_t stream) {
int grid_dim = batch_size * seq_len;
float *nll_loss_buffer = loss_buffer + grid_dim;
ls_cross_entropy_fw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
epsilon, vocab_size);
int num_items = grid_dim;
void *d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
loss_buffer, outputs_ptr,
num_items, stream));
CHECK_GPU_ERROR(
g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
loss_buffer, outputs_ptr,
num_items, stream));
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
nll_loss_buffer, nll_loss_ptr,
num_items, stream));
CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
}
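// Note on the cub reductions above: the first DeviceReduce::Sum call passes a
// NULL workspace only to query temp_storage_bytes; the workspace is then
// taken from the caching allocator, the two real reductions (smoothed losses
// and nll losses) are run, and the workspace is returned.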
template void launch_cross_entropy_fw<float>(
const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template void launch_cross_entropy_fw<__half>(
const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template <typename T>
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
const int *targets_ptr, T *grad_inputs_ptr,
const int padding_idx, const float epsilon,
const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream) {
int grid_dim = batch_size * seq_len;
ls_cross_entropy_bw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
epsilon, vocab_size);
}
template void launch_cross_entropy_bw<float>(
const float *grad_outputs_ptr, const float *inputs_ptr,
const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template void launch_cross_entropy_bw<__half>(
const float *grad_outputs_ptr, const __half *inputs_ptr,
const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
*/
#include "cublas_wrappers.h"
#ifdef COLOSSAL_HIP
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const float *A,
const float *B, float *C, rocblas_gemm_algo algo) {
  cublasStatus_t status = rocblas_gemm_ex(
      handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
      rocblas_datatype_f32_r, (transa == rocblas_operation_none) ? m : k,
      (const void *)B, rocblas_datatype_f32_r,
      (transb == rocblas_operation_none) ? k : n, (const void *)beta, C,
      rocblas_datatype_f32_r, m, C, rocblas_datatype_f32_r, m,
      rocblas_datatype_f32_r, algo, 0, 0);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const __half *A,
const __half *B, __half *C, rocblas_gemm_algo algo) {
  cublasStatus_t status = rocblas_gemm_ex(
      handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
      rocblas_datatype_f16_r, (transa == rocblas_operation_none) ? m : k,
      (const void *)B, rocblas_datatype_f16_r,
      (transb == rocblas_operation_none) ? k : n, (const void *)beta,
      (void *)C, rocblas_datatype_f16_r, m, (void *)C, rocblas_datatype_f16_r,
      m, rocblas_datatype_f32_r, algo, 0, 0);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const float *A, const float *B, float *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch, rocblas_gemm_algo algo) {
  cublasStatus_t status = rocblas_gemm_strided_batched_ex(
      handle, op_A, op_B, m, n, k, alpha, A, rocblas_datatype_f32_r,
      (op_A == rocblas_operation_none) ? m : k, stride_A, B,
      rocblas_datatype_f32_r, (op_B == rocblas_operation_none) ? k : n,
      stride_B, beta, C, rocblas_datatype_f32_r, m, stride_C, C,
      // the output (D) descriptor must be fp32 to match C in this overload
      rocblas_datatype_f32_r, m, stride_C, batch, rocblas_datatype_f32_r, algo,
      0, 0);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
"error: %d) \n",
batch, m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const __half *A, const __half *B, __half *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch, rocblas_gemm_algo algo) {
  cublasStatus_t status = rocblas_gemm_strided_batched_ex(
      handle, op_A, op_B, m, n, k, alpha, A, rocblas_datatype_f16_r,
      (op_A == rocblas_operation_none) ? m : k, stride_A, B,
      rocblas_datatype_f16_r, (op_B == rocblas_operation_none) ? k : n,
      stride_B, beta, C, rocblas_datatype_f16_r, m, stride_C, C,
      rocblas_datatype_f16_r, m, stride_C, batch, rocblas_datatype_f32_r, algo,
      0, 0);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
#else
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const float *A,
const float *B, float *C, cublasGemmAlgo_t algo) {
cublasStatus_t status =
cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha,
(const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k,
(const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n,
(const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const __half *A,
const __half *B, __half *C, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmEx(
handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F,
(transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C,
CUDA_R_16F, m, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const float *A, const float *B, float *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmStridedBatchedEx(
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F,
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F,
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C,
batch, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
"error: %d) \n",
batch, m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const __half *A, const __half *B, __half *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmStridedBatchedEx(
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F,
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F,
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C,
batch, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
#endif
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#ifdef COLOSSAL_HIP
#include <thrust/transform_reduce.h>
#include "hip_util.h"
#else
#include "cuda_util.h"
#endif
/* GPU function guard */
std::string _cudaGetErrorString(cudaError_t error) {
return cudaGetErrorString(error);
}
std::string _cudaGetErrorString(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
#ifndef COLOSSAL_HIP
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
#endif
}
return "CUBLAS_UNKNOW";
}
template <typename T>
void check_gpu_error(T result, char const *const func, const char *const file,
int const line) {
if (result) {
    throw std::runtime_error(std::string("[CUDA][ERROR] ") + file + "(" +
                             std::to_string(line) +
                             "): " + (_cudaGetErrorString(result)) + "\n");
}
}
template void check_gpu_error<cudaError_t>(cudaError_t result,
char const *const func,
const char *const file,
int const line);
template void check_gpu_error<cublasStatus_t>(cublasStatus_t result,
char const *const func,
const char *const file,
int const line);
template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele) {
std::cout << outn << ": ";
std::vector<T> hout(num_output_ele, (T)0);
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T),
cudaMemcpyDeviceToHost);
for (int i = 0; i < num_output_ele; i++) {
std::cout << hout[i] << ", ";
}
std::cout << std::endl;
}
template <>
void print_vec<__half>(const __half *outv, std::string outn,
int num_output_ele) {
std::cout << outn << ": ";
std::vector<__half> hout(num_output_ele, (__half)0.f);
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half),
cudaMemcpyDeviceToHost);
for (int i = 0; i < num_output_ele; i++) {
std::cout << __half2float(hout[i]) << ", ";
}
std::cout << std::endl;
}
template void print_vec<float>(const float *outv, std::string outn,
int num_output_ele);
template void print_vec<int>(const int *outv, std::string outn,
int num_output_ele);
template void print_vec<__half>(const __half *outv, std::string outn,
int num_output_ele);
template <typename T>
T *cuda_malloc(size_t ele_num) {
size_t byte_size = ele_num * sizeof(T);
T *pdata = nullptr;
CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size));
return pdata;
}
template float *cuda_malloc<float>(size_t ele_num);
template __half *cuda_malloc<__half>(size_t ele_num);
template uint8_t *cuda_malloc<uint8_t>(size_t ele_num);
void cuda_free(void *pdata) {
if (pdata != nullptr) {
cudaFree(pdata);
}
}
template <typename T>
struct _isnan {
__device__ bool operator()(T a) const { return isnan(a); }
};
template <>
struct _isnan<__half> {
__device__ bool operator()(const __half a) const { return __hisnan(a); }
};
template <typename T>
struct _isinf {
__device__ bool operator()(T a) const { return isinf(a); }
};
template <>
struct _isinf<__half> {
__device__ bool operator()(const __half a) const { return __hisinf(a); }
};
template <typename T>
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
std::string file, int line, cudaStream_t stream) {
  // check_nan_inf = true : check for nan
  // check_nan_inf = false: check for inf
bool res = false;
std::string msg = file + "(" + std::to_string(line) + "): ";
if (check_nan_inf) {
msg += "nan.";
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
data_ptr + dsize, _isnan<T>(), false,
thrust::logical_or<bool>());
} else {
msg += "inf.";
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
data_ptr + dsize, _isinf<T>(), false,
thrust::logical_or<bool>());
}
if (res) {
throw std::runtime_error(msg);
}
std::cout << msg << " [check pass]." << std::endl;
}
template void check_nan_inf<float>(const float *data_ptr, int dsize,
bool check_nan_inf, std::string file,
int line, cudaStream_t stream);
template void check_nan_inf<__half>(const __half *data_ptr, int dsize,
bool check_nan_inf, std::string file,
int line, cudaStream_t stream);
#include <chrono>
#include <ctime>
#include "kernels.h"
#ifdef COLOSSAL_HIP
#include <hiprand/hiprand_kernel_hcc.h>
#endif
#ifndef COLOSSAL_HIP
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
#endif
curandStatePhilox4_32_10_t *curandstate;
/**
* @brief element-wise activation function on device, like Relu, Gelu
*
* @tparam enum class ActivationType, kRelu, kGelu
* @tparam input type
 * @param x input of any shape, float or __half2
 * @return same shape and type as input
*/
template <ActivationType, typename T>
__forceinline__ __device__ T activation_kernel(T x);
template <>
__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
float cdf =
0.5f *
(1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
return x * cdf;
}
template <>
__device__ __half2
activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
__half2 val_pow3 = __hmul2(val, __hmul2(val, val));
float2 tmp_pow = __half22float2(val_pow3);
float2 tmp = __half22float2(val);
tmp.x =
0.5f *
(1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
tmp.y =
0.5f *
(1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
return __hmul2(val, __float22half2_rn(tmp));
}
template <>
__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
return fmaxf(x, 0);
}
template <>
__device__ __half2
activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
#ifdef COLOSSAL_HIP
float2 tmp = __half22float2(x);
return __floats2half2_rn(fmaxf(0.f, tmp.x),
fmaxf(0.f, tmp.y));
#else
return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)),
fmaxf(0.f, __half2float(x.y)));
#endif
}
/**
* @brief element-wise activation backward function on device
*
* @tparam enum class ActivationType
* @tparam input type
 * @param grad, x gradient and input of any shape, float or __half2
 * @return same shape as input
*/
template <ActivationType, typename T>
__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
template <>
__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
float x) {
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return grad * (dg1 + dg2 + dg3);
}
template <>
__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
__half grad, __half x_half) {
float x = __half2float(x_half);
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return grad * __float2half(dg1 + dg2 + dg3);
}
template <>
__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
float x) {
return x > 0.f ? grad : 0.f;
}
template <>
__device__ __half
activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
const __half half_zero = __float2half(0.f);
return x > half_zero ? grad : half_zero;
}
template <>
__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
__half2 grad2, __half2 x_half2) {
#ifdef COLOSSAL_HIP
float2 tmp_x = __half22float2(x_half2);
float2 tmp_grad2 = __half22float2(grad2);
return __floats2half2_rn(tmp_x.x > 0.0 ? tmp_grad2.x : 0.0,
tmp_x.y > 0.0 ? tmp_grad2.y : 0.0);
#else
const __half half_zero = __float2half(0.f);
return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero,
x_half2.y > half_zero ? grad2.y : half_zero);
#endif
}
/**
* @brief init curand states in global memory
*
 * @thread grid_dim * block_dim to support any size of states
 * @param state persistent curand states
* @param seed seed to init states
* @return void
*/
__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
int seed) {
/* Each thread gets same seed, a different sequence
number, no offset */
int id = threadIdx.x + blockIdx.x * blockDim.x;
curand_init(seed, id, 0, &state[id]);
}
void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t));
int grid_dim = total_count >> 9;
curand_init_kernel<<<grid_dim, 512, 0, stream>>>(
curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count());
}
/**
* @brief element-wise dropout, store dropped position in mask, it's not
* in-place
*
* @thread
* gridDim.x = total_count / 1024
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param out any size of float and __half
* @param in same with out
* @param mask uint8 type, same size with out
* @param seed seed to curand
* @return void
*/
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
float *__restrict__ out,
const float *__restrict__ in,
uint8_t *__restrict__ mask, const int seed) {
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
uint8_t m[4];
float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *data4 = reinterpret_cast<const float4 *>(in);
uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
float4 rand = curand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
mask4[i] = m4[0];
float4 input4 = data4[i];
float4 res4;
res4.x = input4.x * scale * m[0];
res4.y = input4.y * scale * m[1];
res4.z = input4.z * scale * m[2];
res4.w = input4.w * scale * m[3];
out4[i] = res4;
}
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
__half *__restrict__ out,
const __half *__restrict__ in,
uint8_t *__restrict__ mask, const int seed) {
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
float4 *outs_float4 = reinterpret_cast<float4 *>(out);
uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
uint8_t m[8];
float4 rand = curand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
rand = curand_uniform4(&state);
m[4] = (uint8_t)(rand.x > ratio);
m[5] = (uint8_t)(rand.y > ratio);
m[6] = (uint8_t)(rand.z > ratio);
m[7] = (uint8_t)(rand.w > ratio);
uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
mask8[i] = *m8;
float4 val_float4 = vals_float4[i];
float4 out_float4;
__half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
__half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
__half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
__half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
__half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
__half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
outs_float4[i] = out_float4;
}
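// Both dropout kernels above implement inverted dropout: kept elements are
// scaled by 1 / (1 - ratio) at training time so no rescaling is needed at
// inference. The float kernel processes 4 elements per thread via float4
// loads, the __half kernel 8 elements per thread, and the keep/drop decisions
// are packed into the uint8 mask for reuse in the backward pass.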
/**
* @brief element-wise dropout backward with dropout mask, it's
* not in-place
*
* @thread
* gridDim.x = total_count / 1024
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param in any size of float and __half
* @param mask uint8 type, same size with in
* @return void
*/
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
float *out, const float *in,
const uint8_t *__restrict__ mask) {
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return;
uint8_t m[4];
float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *in4 = reinterpret_cast<const float4 *>(in);
const uint32_t *mask4 = reinterpret_cast<const uint32_t *>(mask);
uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
m4[0] = mask4[i];
float4 input4 = in4[i];
float4 res4;
res4.x = input4.x * scale * static_cast<float>(m[0]);
res4.y = input4.y * scale * static_cast<float>(m[1]);
res4.z = input4.z * scale * static_cast<float>(m[2]);
res4.w = input4.w * scale * static_cast<float>(m[3]);
out4[i] = res4;
}
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
__half *out, const __half *in,
const uint8_t *__restrict__ mask) {
const __half scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return;
float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
const uint64_t *mask8 = reinterpret_cast<const uint64_t *>(mask);
uint8_t m[8];
uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
m8[0] = mask8[i];
float4 val_float4 = vals_float4[i];
float4 out_float4;
__half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
__half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
__half2 scale_mask_1 =
__halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
__half2 scale_mask_2 =
__halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
__half2 scale_mask_3 =
__halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
__half2 scale_mask_4 =
__halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
out4[i] = out_float4;
}
template <>
void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
int total_count, float ratio, cudaStream_t stream,
bool backward) {
int grid_dim = total_count >> 12;
if (!backward) {
ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
total_count, ratio, out, vals, mask,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count());
} else {
ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
out, vals, mask);
}
}
template <>
void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
int total_count, float ratio,
cudaStream_t stream, bool backward) {
int grid_dim = total_count >> 13;
if (!backward) {
ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
total_count, ratio, out, vals, mask,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count());
} else {
ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
out, vals, mask);
}
}
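// Grid sizing above: with 1024 threads per block and 4 floats (or 8 halves)
// per thread, one block covers 4096 (or 8192) elements, hence the shifts by
// 12 and 13; the extra "+ 1" block handles any remainder.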
/**
* @brief fused bias, dropout, and residual at the end of Attention and FFN,
* store dropped position in mask, it's not in-place
*
* @thread
* gridDim.x = total_count / 1024
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param out [batch_size, seq_len, hidden_size], float and __half
* @param in [batch_size, seq_len, hidden_size], float and __half
* @param mask [batch_size, seq_len, hidden_size], uint8 type
* @param bias [hidden_size], ffn bias
* @param residual [batch_size, seq_len, hidden_size], float and __half
* @param seed seed to curand
* @param hidden_size hidden size
* @return void
*/
__global__ void ls_dropout_res_bias_kernel(
const int total_count, const float ratio, float *__restrict__ out,
const float *__restrict__ in, uint8_t *__restrict__ mask,
const float *__restrict__ bias, const float *__restrict__ residual,
const int seed, const int hidden_size) {
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
uint8_t m[4];
float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *data4 = reinterpret_cast<const float4 *>(in);
const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
float4 rand = curand_uniform4(&state);
m[0] = static_cast<uint8_t>(rand.x > ratio);
m[1] = static_cast<uint8_t>(rand.y > ratio);
m[2] = static_cast<uint8_t>(rand.z > ratio);
m[3] = static_cast<uint8_t>(rand.w > ratio);
int bias_i = i % (hidden_size >> 2);
uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
mask4[i] = m4[0];
const float4 input4 = data4[i];
const float4 b4 = __ldg(&bias4[bias_i]);
const float4 res4 = residual4[i];
float4 output4;
output4.x = (input4.x + b4.x) * scale * m[0] + res4.x;
output4.y = (input4.y + b4.y) * scale * m[1] + res4.y;
output4.z = (input4.z + b4.z) * scale * m[2] + res4.z;
output4.w = (input4.w + b4.w) * scale * m[3] + res4.w;
out4[i] = output4;
}
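// Per element, the kernel above computes
//   out = (in + bias) * mask * scale + residual
// i.e. dropout is applied to the biased activation before the residual is
// added back; the __half variant below does the same on __half2 pairs.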
__global__ void ls_dropout_res_bias_kernel(
const int total_count, const float ratio, __half *__restrict__ out,
const __half *__restrict__ in, uint8_t *__restrict__ mask,
const __half *__restrict__ bias, const __half *__restrict__ residual,
const int seed, const int hidden_size) {
const __half scale = 1. / (1. - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
float4 *outs_float4 = reinterpret_cast<float4 *>(out);
const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
uint8_t m[8];
float4 rand = curand_uniform4(&state);
m[0] = static_cast<uint8_t>(rand.x > ratio);
m[1] = static_cast<uint8_t>(rand.y > ratio);
m[2] = static_cast<uint8_t>(rand.z > ratio);
m[3] = static_cast<uint8_t>(rand.w > ratio);
rand = curand_uniform4(&state);
m[4] = static_cast<uint8_t>(rand.x > ratio);
m[5] = static_cast<uint8_t>(rand.y > ratio);
m[6] = static_cast<uint8_t>(rand.z > ratio);
m[7] = static_cast<uint8_t>(rand.w > ratio);
uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
mask8[i] = m8[0];
int bias_i = i % (hidden_size >> 3);
float4 val_float4 = vals_float4[i];
const float4 b4 = __ldg(&bias4[bias_i]);
const float4 res4 = residual4[i];
float4 out_float4;
__half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
__half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);
__half2 scale_mask_1 =
__halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
__half2 scale_mask_2 =
__halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
__half2 scale_mask_3 =
__halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
__half2 scale_mask_4 =
__halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
out_half2[0] =
__hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]);
out_half2[1] =
__hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]);
out_half2[2] =
__hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]);
out_half2[3] =
__hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]);
outs_float4[i] = out_float4;
}
template <>
void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
uint8_t *mask, const float *bias,
const float *residual, int total_count,
int dim, float ratio,
cudaStream_t stream) {
int grid_dim = total_count >> 12;
ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
total_count, ratio, out, vals, mask, bias, residual,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count(),
dim);
}
template <>
void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
uint8_t *mask, const __half *bias,
const __half *residual, int total_count,
int dim, float ratio,
cudaStream_t stream) {
int grid_dim = total_count >> 13;
ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
total_count, ratio, out, vals, mask, bias, residual,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count(),
dim);
}
/**
* @brief fused bias and dropout backward at the end of Attention and FFN
*
* @thread
* gridDim.x = hidden_size / 8
* blockDim.x = 8
* blockDim.y = 1024 / 8 = 128
*
* @param row_size batch_size * seq_len
* @param ratio dropout ratio
* @param in_grad [batch_size, seq_len, hidden_size], input grad
* @param bias_grad [hidden_size], bias grad
* @param out_grad [batch_size, seq_len, hidden_size], output grad
* @param mask [batch_size, seq_len, hidden_size], dropout mask
* @param hidden_size
* @return void
*/
__global__ void ls_dropout_bias_bwd_kernel(
const int row_size, const float ratio, float *__restrict__ in_grad,
float *__restrict__ bias_grad, const float *__restrict__ out_grad,
const uint8_t *__restrict__ mask, const int hidden_size) {
const float scale = 1.f / (1.f - ratio);
  // each block produces 8 bias results
__shared__ float tile[8][129];
#ifndef COLOSSAL_HIP
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
#endif
int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
int stride = hidden_size * 128;
float local_sum = 0;
int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
for (int r = threadIdx.y; r < row_size; r += 128) {
float val = out_grad[idx];
val *= scale * static_cast<float>(mask[idx]);
local_sum += val;
in_grad[idx] = val;
idx += stride;
}
tile[threadIdx.x][threadIdx.y] = local_sum;
__syncthreads();
float sum = 0;
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int x = tid >> 7;
int y = tid & (127);
if (y < 32) {
#pragma unroll
for (int i = 0; i < 4; i++) {
sum += tile[x][y + i * 32];
}
}
__syncthreads();
#ifdef COLOSSAL_HIP
for (int i = 1; i < 32; i <<= 1) sum += __shfl_down(sum, i);
#else
for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
#endif
if (y == 0) tile[0][x] = sum;
__syncthreads();
if (threadIdx.x < 8) {
int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
bias_grad[pos] = tile[0][threadIdx.x];
}
}
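// The kernel above produces both the input gradient (out_grad rescaled by the
// saved dropout mask) and the bias gradient. Each block owns 8 bias columns;
// per-thread partial column sums are staged in the shared-memory tile (padded
// to 129 columns to avoid bank conflicts), folded across the 128 rows, and a
// final warp shuffle writes one bias_grad value per column.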
__global__ void ls_dropout_bias_bwd_kernel(
const int row_size, const float ratio, __half *__restrict__ in_grad,
__half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
const uint8_t *__restrict__ mask, const int hidden_size) {
const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
__shared__ __half2 tile[8][129];
#ifndef COLOSSAL_HIP
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
#endif
__half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
__half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
int stride = hidden_size * 128;
__half2 local_sum = __float2half2_rn(0.f);
int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
for (int r = threadIdx.y; r < row_size; r += 128) {
__half2 val = out_grad2[idx];
__half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
val *= scale * m2;
local_sum += val;
in_grad2[idx] = val;
idx += stride;
}
tile[threadIdx.x][threadIdx.y] = local_sum;
__syncthreads();
__half2 sum = __float2half2_rn(0.f);
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int x = tid >> 7;
int y = tid & (127);
if (y < 32) {
#pragma unroll
for (int i = 0; i < 4; i++) {
sum += tile[x][y + i * 32];
}
}
__syncthreads();
#ifdef COLOSSAL_HIP
float2 sum_f2 = __half22float2(sum);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum_f2.x += __shfl_down(sum_f2.x, i);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum_f2.y += __shfl_down(sum_f2.y, i);
sum = __float22half2_rn(sum_f2);
#else
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
#endif
if (y == 0) tile[0][x] = sum;
__syncthreads();
if (threadIdx.x < 8) {
int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
bias_grad2[pos] = tile[0][threadIdx.x];
}
}
template <typename T>
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream) {
dim3 grid_dim((dim - 1) / 8 + 1);
dim3 block_dim(8, 128);
ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}
template <>
void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
const __half *out_grad, const uint8_t *mask,
int row_size, int dim, float ratio,
cudaStream_t stream) {
dim >>= 1;
dim3 grid_dim((dim - 1) / 8 + 1);
dim3 block_dim(8, 128);
ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}
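// In the __half specialization above, dim is halved before launch because the
// kernel reinterprets the buffers as __half2 and therefore reduces over
// hidden_size / 2 packed columns.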
template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
const float *out_grad,
const uint8_t *mask, int row_size,
int dim, float ratio,
cudaStream_t stream);
/**
* @brief fused bias, activation, and dropout at the end of first ffn
*
 * @thread
 * gridDim.x = total_count / 1024 + 1 (float) or total_count / 2048 + 1 (half)
 * blockDim.x = 256
*
* @tparam act_type activation function, like kRelu, kGelu
* @param total_count total elements
* @param ratio drop ratio
* @param out [batch_size, seq_len, hidden_size], float and __half
* @param in [batch_size, seq_len, hidden_size], float and __half
* @param mask [batch_size, seq_len, hidden_size], uint8 type
* @param bias [hidden_size], ffn bias
* @param seed seed to curand
* @param hidden_size
* @return void
*/
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
const int total_count, const float ratio, float *__restrict__ out,
const float *__restrict__ in, uint8_t *__restrict__ mask,
const float *__restrict__ bias, const int seed, const int hidden_size) {
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
uint8_t m[4];
float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *data4 = reinterpret_cast<const float4 *>(in);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
float4 rand = curand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int bias_i = i % (hidden_size >> 2);
uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
mask4[i] = m4[0];
const float4 input4 = data4[i];
const float4 b4 = __ldg(&bias4[bias_i]);
float4 output4;
output4.x =
activation_kernel<act_type, float>(input4.x + b4.x) * scale * m[0];
output4.y =
activation_kernel<act_type, float>(input4.y + b4.y) * scale * m[1];
output4.z =
activation_kernel<act_type, float>(input4.z + b4.z) * scale * m[2];
output4.w =
activation_kernel<act_type, float>(input4.w + b4.w) * scale * m[3];
out4[i] = output4;
}
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
const int total_count, const float ratio, __half *__restrict__ out,
const __half *__restrict__ in, uint8_t *__restrict__ mask,
const __half *__restrict__ bias, const int seed, const int hidden_size) {
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
float4 *outs_float4 = reinterpret_cast<float4 *>(out);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
uint8_t m[8];
float4 rand = curand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
rand = curand_uniform4(&state);
m[4] = (uint8_t)(rand.x > ratio);
m[5] = (uint8_t)(rand.y > ratio);
m[6] = (uint8_t)(rand.z > ratio);
m[7] = (uint8_t)(rand.w > ratio);
uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
mask8[i] = *m8;
int bias_i = i % (hidden_size >> 3);
float4 val_float4 = vals_float4[i];
const float4 b4 = __ldg(&bias4[bias_i]);
float4 out_float4;
__half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
__half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
__half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
__half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
__half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
__half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
out_half2[0] = __hmul2(
activation_kernel<act_type, __half2>(__hadd2(val_half2[0], b_half2[0])),
scale_mask_1);
out_half2[1] = __hmul2(
activation_kernel<act_type, __half2>(__hadd2(val_half2[1], b_half2[1])),
scale_mask_2);
out_half2[2] = __hmul2(
activation_kernel<act_type, __half2>(__hadd2(val_half2[2], b_half2[2])),
scale_mask_3);
out_half2[3] = __hmul2(
activation_kernel<act_type, __half2>(__hadd2(val_half2[3], b_half2[3])),
scale_mask_4);
outs_float4[i] = out_float4;
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
float *out, const float *vals, uint8_t *mask, const float *bias,
int total_count, int dim, float ratio, cudaStream_t stream) {
int grid_dim = total_count >> 10;
ls_dropout_act_bias_kernel<ActivationType::kGelu>
<<<grid_dim + 1, 256, 0, stream>>>(
total_count, ratio, out, vals, mask, bias,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count(),
dim);
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
__half *out, const __half *vals, uint8_t *mask, const __half *bias,
int total_count, int dim, float ratio, cudaStream_t stream) {
int grid_dim = total_count >> 11;
ls_dropout_act_bias_kernel<ActivationType::kGelu>
<<<grid_dim + 1, 256, 0, stream>>>(
total_count, ratio, out, vals, mask, bias,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count(),
dim);
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
float *out, const float *vals, uint8_t *mask, const float *bias,
int total_count, int dim, float ratio, cudaStream_t stream) {
int grid_dim = total_count >> 10;
ls_dropout_act_bias_kernel<ActivationType::kRelu>
<<<grid_dim + 1, 256, 0, stream>>>(
total_count, ratio, out, vals, mask, bias,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count(),
dim);
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
__half *out, const __half *vals, uint8_t *mask, const __half *bias,
int total_count, int dim, float ratio, cudaStream_t stream) {
int grid_dim = total_count >> 11;
ls_dropout_act_bias_kernel<ActivationType::kRelu>
<<<grid_dim + 1, 256, 0, stream>>>(
total_count, ratio, out, vals, mask, bias,
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count(),
dim);
}
/**
* @brief fused bias, activation, and dropout backward
*
 * @thread
 * gridDim.x = hidden_size / WARP_SIZE
 * blockDim.x = WARP_SIZE
 * blockDim.y = WARP_SIZE
*
* @tparam act_type kRelu
* @param row_size batch_size * seq_len
* @param ratio dropout ratio
* @param in_grad [batch_size, seq_len, hidden_size], input grad
* @param bias_grad [hidden_size], bias grad
* @param out_grad [batch_size, seq_len, hidden_size], output grad
* @param mask [batch_size, seq_len, hidden_size], dropout mask
* @param hidden_size
* @return void
*/
template <ActivationType act_type, typename T>
__global__ void ls_dropout_act_bias_bwd_kernel(
const int row_size, const float ratio, T *in_grad,
T *__restrict__ bias_grad, const T *__restrict__ input,
const T *__restrict__ bias, const T *out_grad,
const uint8_t *__restrict__ mask, const int hidden_size) {
const float scale = 1.f / (1.f - ratio);
__shared__ float tile[WARP_SIZE][WARP_SIZE + 1];
#ifndef COLOSSAL_HIP
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
#endif
int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
int stride = hidden_size * WARP_SIZE;
float local_sum = 0;
int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
if (col_idx < hidden_size) {
for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
float val = out_grad[idx];
float in = input[idx];
      // bias value for this output column
      float bias_val = bias[idx % hidden_size];
      val = activation_bwd_kernel<act_type, float>(
          val * scale * static_cast<float>(mask[idx]), in + bias_val);
local_sum += val;
in_grad[idx] = val;
idx += stride;
}
}
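  // Stage the per-thread partial sums in shared memory: row x of `tile`
  // collects the WARP_SIZE partials of column blockIdx.x * WARP_SIZE + x.
  // Reading the tile transposed then lets one warp reduce each column with a
  // warp shuffle; lane 0 stores the total and the first warp writes bias_grad.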
tile[threadIdx.x][threadIdx.y] = local_sum;
__syncthreads();
float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads();
#ifdef COLOSSAL_HIP
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += __shfl_down(sum, i);
#else
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
#endif
if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
__syncthreads();
  if (threadIdx.y == 0) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
    // guard the final block when hidden_size is not a multiple of WARP_SIZE
    if (pos < hidden_size) bias_grad[pos] = tile[0][threadIdx.x];
  }
}
// @brief fused bias, activation, and dropout backward
// It is deprecated for precision reasons. Kept here for future optimization.
//
// template <ActivationType act_type>
// __global__ void ls_dropout_act_bias_bwd_kernel(
// const int row_size, const float ratio, __half * in_grad,
// __half *__restrict__ bias_grad, const __half *__restrict__ input, const
// __half *__restrict__ bias, const __half * out_grad, const uint8_t
// *__restrict__ mask, const int hidden_size) {
// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];
// cg::thread_block b = cg::this_thread_block();
// cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
// const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
// const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
// const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);
// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// int stride = hidden_size * WARP_SIZE;
// __half2 local_sum = __float2half2_rn(0.f);
// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
// if (col_idx < hidden_size) {
// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
// __half2 val = out_grad2[idx];
// __half2 in2 = input2[idx];
//       __half2 b2 = bias2[idx % hidden_size];
//       __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
//       val = activation_bwd_kernel<ActivationType::kRelu, __half2>(
//           val * scale * m2, in2 + b2);
// local_sum += val;
// in_grad2[idx] = val;
// idx += stride;
// }
// }
// tile[threadIdx.x][threadIdx.y] = local_sum;
// __syncthreads();
// __half2 sum = tile[threadIdx.y][threadIdx.x];
// __syncthreads();
// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
// __syncthreads();
// if (threadIdx.y == 0) {
// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// bias_grad2[pos] = tile[0][threadIdx.x];
// }
// }
template <ActivationType act_type, typename T>
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
const T *bias, const T *out_grad,
const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream) {
dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE);
ls_dropout_act_bias_bwd_kernel<act_type><<<grid_dim, block_dim, 0, stream>>>(
row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim);
}
// template <>
// void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
// __half *in_grad, __half *bias_grad,const __half *input, const __half
// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int
// dim, float ratio, cudaStream_t stream) {
// dim >>= 1;
// dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
// dim3 block_dim(WARP_SIZE, WARP_SIZE);
// ls_dropout_act_bias_bwd_kernel<ActivationType::kRelu>
//       <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
//                                            bias_grad, input, bias,
//                                            out_grad, mask, dim);
// }
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
float *in_grad, float *bias_grad, const float *input, const float *bias,
const float *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
__half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
const __half *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
float *in_grad, float *bias_grad, const float *input, const float *bias,
const float *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
__half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
const __half *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
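// Illustrative host-side call for the backward launcher; the device pointers
// and sizes below are hypothetical placeholders, not defined in this file.
// The grid covers ceil(hidden_size / WARP_SIZE) blocks of WARP_SIZE x
// WARP_SIZE threads, and row_size is the number of rows reduced per column.
//
//   launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
//       d_in_grad, d_bias_grad, d_input, d_bias, d_out_grad, d_mask,
//       batch_size * seq_len, hidden_size, 0.1f, stream);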
#include "kernels.h"
#ifndef COLOSSAL_HIP
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
#endif
/**
@brief: fuse_transpose_bias
Calculate the sum of elements in each column of the matrix.
@thread
gridDim.x = ceil(cols / WARP_SIZE)
blockDim.x = WARP_SIZE
blockDim.y = WARP_SIZE
@param
inp: [rows, cols]
out: [cols]
rows: the number of rows in the matrix
cols: the number of cols in the matrix
*/
template <typename T>
__global__ void column_sum_reduce(const T *__restrict__ inp,
T *__restrict__ out, int rows, int cols) {
__shared__ float tile[WARP_SIZE][WARP_SIZE];
#ifndef COLOSSAL_HIP
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
#endif
int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
int y_stride = cols * WARP_SIZE;
float localSum = 0;
// Loop across matrix row
// TODO: optimize to log complexity
if (idx < cols) {
int offset = flat_2dim(threadIdx.y, idx, cols);
for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
  // The sum of a row of `tile` equals the sum of a column of the original matrix
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
  // Reduce the shared buffer: read the tile transposed so that consecutive
  // threadIdx.x lanes touch consecutive elements of one tile row
float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads();
// Calculate the sum of a row in tile
#ifdef COLOSSAL_HIP
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += __shfl_down(sum, i);
#else
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
#endif
if (threadIdx.x == 0) {
int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
if (pos < cols) out[pos] = sum;
}
}
// [r, c] -> [c]
template <>
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
int rows, int cols,
cudaStream_t stream) {
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE);
column_sum_reduce<float>
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
int rows, int cols,
cudaStream_t stream) {
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE);
column_sum_reduce<__half>
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
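// Illustrative use: reducing an output gradient of shape
// [batch_size * seq_len, hidden_size] to a bias gradient of shape
// [hidden_size]. The pointers and sizes are hypothetical placeholders.
//
//   launch_fuse_transpose_bias_kernel<float>(
//       d_out_grad, d_bias_grad, batch_size * seq_len, hidden_size, stream);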
/**
@brief: fused_add2
Add two matrix inp1 and inp2 to out.
@thread
gridDim.x = batch_size * seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
inp1: [batch_size, seq_len, hidden_dim]
inp2: [batch_size, seq_len, hidden_dim]
out: [batch_size, seq_len, hidden_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
*/
template <typename T>
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
int hidden_dim);
template <>
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
const float *inp2, int hidden_dim) {
int row_id = blockIdx.x;
int offset = flat_2dim(row_id, 0, hidden_dim);
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
float4 *out_4 = reinterpret_cast<float4 *>(out);
float4 vinp1;
float4 vinp2;
float4 val;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinp1 = inp1_4[offset + i];
vinp2 = inp2_4[offset + i];
val.x = vinp1.x + vinp2.x;
val.y = vinp1.y + vinp2.y;
val.z = vinp1.z + vinp2.z;
val.w = vinp1.w + vinp2.w;
out_4[offset + i] = val;
}
}
template <>
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
const __half *inp2, int hidden_dim) {
int row_id = blockIdx.x;
int offset = flat_2dim(row_id, 0, hidden_dim);
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
float4 *out_4 = reinterpret_cast<float4 *>(out);
float4 vinp1;
float4 vinp2;
float4 val;
__half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1);
__half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2);
__half2 *h2_val = reinterpret_cast<__half2 *>(&val);
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinp1 = inp1_4[offset + i];
vinp2 = inp2_4[offset + i];
h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]);
h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]);
h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]);
h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]);
out_4[offset + i] = val;
}
}
//[b, s, h] -> [b, s, h]
template <>
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
int batch_size, int seq_len, int hidden_dim,
cudaStream_t &stream) {
hidden_dim >>= 2;
dim3 grid_dim(batch_size * seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
hidden_dim);
}
template <>
void launch_fused_add2<__half>(__half *out, const __half *inp1,
const __half *inp2, int batch_size, int seq_len,
int hidden_dim, cudaStream_t &stream) {
hidden_dim >>= 3;
dim3 grid_dim(batch_size * seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
hidden_dim);
}
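// Usage note: the launchers divide hidden_dim by 4 (float) or 8 (__half)
// because the kernels load float4 vectors, so hidden_dim is expected to be a
// multiple of 4 (float) or 8 (__half). Illustrative call with hypothetical
// placeholders:
//
//   launch_fused_add2<float>(d_out, d_inp1, d_inp2, batch_size, seq_len,
//                            hidden_dim, stream);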
template <typename T>
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
int sz0, int sz2, int sz1_1, int sz1_2) {
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
if (idx >= nele) {
return;
}
float4 *dst_ptr = (float4 *)output + idx;
int idx2 = idx % sz2;
idx = idx / sz2;
int idx1 = idx % (sz1_1 + sz1_2);
int idx0 = idx / (sz1_1 + sz1_2);
float4 *src_ptr = nullptr;
int sz1 = 0;
if (idx1 < sz1_1) {
sz1 = sz1_1;
src_ptr = (float4 *)inp1;
} else {
idx1 -= sz1_1;
sz1 = sz1_2;
src_ptr = (float4 *)inp2;
}
src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2);
dst_ptr[0] = src_ptr[0];
}
template <>
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
float *output, int sz0, int sz2, int sz1_1,
int sz1_2, cudaStream_t stream) {
sz2 >>= 2;
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}
template <>
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
__half *output, int sz0, int sz2, int sz1_1,
int sz1_2, cudaStream_t stream) {
sz2 >>= 3;
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}
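// Usage note: the inputs are treated as [sz0, sz1_1, sz2] and [sz0, sz1_2,
// sz2] and concatenated along dim 1; sz2 is divided by 4 (float) or 8
// (__half) for float4 loads, so the innermost dimension is expected to be a
// multiple of 4 or 8. Illustrative call with hypothetical placeholders:
//
//   launch_concat3_dim1<float>(d_inp1, d_inp2, d_output,
//                              sz0, sz2, sz1_1, sz1_2, stream);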