Unverified commit 079bf3cb, authored by Hongxin Liu and committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, List, Optional, Union
import torch
import torch.nn as nn
......@@ -15,7 +15,7 @@ from .kvcache_manager import MemoryManager
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"]
class TPInferEngine:
......@@ -39,14 +39,16 @@ class TPInferEngine:
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""
def __init__(self,
model: nn.Module,
shard_config: ShardConfig,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
dtype: torch.dtype = torch.float16,
device: str = 'cuda') -> None:
def __init__(
self,
model: nn.Module,
shard_config: ShardConfig,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
dtype: torch.dtype = torch.float16,
device: str = "cuda",
) -> None:
self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
......@@ -63,7 +65,7 @@ class TPInferEngine:
self.head_num = model.config.num_attention_heads
self.layer_num = model.config.num_hidden_layers
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
self.shard_config = shard_config
......@@ -74,9 +76,10 @@ class TPInferEngine:
def _init_manager(self) -> None:
    """Create the KV-cache ``MemoryManager`` for this engine.

    Must run after ``self.tp_size`` has been set from the shard config (it
    starts out as the placeholder ``-1``). ``self.head_num`` is overwritten
    in place with the per-rank head count, so calling this method a second
    time would shard the head count again.

    Raises:
        AssertionError: if ``self.tp_size`` is still unset (< 1) or does not
            evenly divide ``self.head_num``.
    """
    assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
    assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
    # Divide exactly once: each tensor-parallel rank only caches its own heads.
    self.head_num //= self.tp_size  # update sharded number of heads
    self.cache_manager = MemoryManager(
        self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
    )
def _optimize_model(self, model: nn.Module) -> None:
"""
......@@ -90,7 +93,7 @@ class TPInferEngine:
self._shard_model_by(shardformer, model)
def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
""" Prepare the engine with a given ShardConfig.
"""Prepare the engine with a given ShardConfig.
Args:
shard_config (ShardConfig): shard config given to specify settings of the engine.
......@@ -118,9 +121,10 @@ class TPInferEngine:
return shard_config
def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
""" Shard original model by the given ShardFormer and store the sharded model. """
assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
"Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
"""Shard original model by the given ShardFormer and store the sharded model."""
assert (
self.tp_size == shardformer.shard_config.tensor_parallel_size
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True)
......@@ -147,7 +151,7 @@ class TPInferEngine:
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].cuda()
if 'max_new_tokens' not in generate_kwargs:
if "max_new_tokens" not in generate_kwargs:
generate_kwargs.update(max_new_tokens=self.max_output_len)
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)
......@@ -176,18 +180,18 @@ class TPInferEngine:
attention_mask = None
if isinstance(inputs, (BatchEncoding, dict)):
input_ids_list = inputs['input_ids']
attention_mask = inputs['attention_mask']
input_ids_list = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
else:
input_ids_list = inputs
if isinstance(input_ids_list[0], int): # for a single input
if isinstance(input_ids_list[0], int): # for a single input
input_ids_list = [input_ids_list]
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
max_len_in_batch = -1
......@@ -210,10 +214,10 @@ class TPInferEngine:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to('cuda')
batch_infer_state.start_loc = seq_start_indexes.to('cuda')
batch_infer_state.seq_len = seq_lengths.to("cuda")
batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
......@@ -248,7 +252,7 @@ class TPInferEngine:
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer
setattr(model, 'infer_state', batch_infer_state)
setattr(model, "infer_state", batch_infer_state)
outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
......@@ -262,14 +266,15 @@ class TPInferEngine:
# as an arg into model.forward.
# It requires rewriting model generate and replacing model forward.
@torch.no_grad()
def _generate_by_pass_infer_state(
    self,
    input_tokens,
    max_out_length: int,
    generation_config: Optional[GenerationConfig] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    **model_kwargs,
) -> torch.Tensor:
    """Placeholder for generation that hands ``BatchInferState`` to the model as an argument.

    The signature is reserved for a future implementation; none of the
    parameters are read yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("generate by passing BatchInferState is not implemented.")
# might want to use in rewritten generate method: use after model.forward
......
......@@ -19,13 +19,15 @@ class MemoryManager:
device: device used to store the key and value cache
"""
def __init__(self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: torch.device = torch.device('cuda')):
def __init__(
self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: torch.device = torch.device("cuda"),
):
self.logger = logging.get_logger(__name__)
self.available_size = size
self.past_key_values_length = 0
......@@ -33,13 +35,13 @@ class MemoryManager:
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
def _init_mem_states(self, size, device):
    """Initialize tensors used to manage memory states.

    Args:
        size: number of cache slots to track.
        device: device on which the bookkeeping tensors are allocated.

    Attributes set on ``self``:
        mem_state: bool tensor of shape ``(size,)``, initially all ``True``.
        mem_cum_sum: uninitialised int32 tensor of shape ``(size,)``; the
            allocation methods use it as the ``out=`` buffer of ``torch.cumsum``.
        indexes: long tensor holding ``0 .. size - 1``.
    """
    self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
    # Scratch buffer only: contents are meaningless until a cumsum is written into it.
    self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
    self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
""" Initialize key buffer and value buffer on specified device """
"""Initialize key buffer and value buffer on specified device"""
self.key_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
]
......@@ -49,10 +51,9 @@ class MemoryManager:
@torch.no_grad()
def alloc(self, required_size):
""" allocate space of required_size by providing indexes representing available physical spaces """
"""allocate space of required_size by providing indexes representing available physical spaces"""
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} "
f"left_size {self.available_size}")
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
......@@ -63,23 +64,25 @@ class MemoryManager:
@torch.no_grad()
def alloc_contiguous(self, required_size):
""" allocate contiguous space of required_size """
"""allocate contiguous space of required_size"""
if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} "
f"left_size {self.available_size}")
self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
sum_size = len(self.mem_cum_sum)
loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size +
1] + self.mem_state[0:sum_size -
required_size + 1]
can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
loc_sums = (
self.mem_cum_sum[required_size - 1 :]
- self.mem_cum_sum[0 : sum_size - required_size + 1]
+ self.mem_state[0 : sum_size - required_size + 1]
)
can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
if can_used_loc.shape[0] == 0:
self.logger.info(f"No enough contiguous cache: required_size {required_size} "
f"left_size {self.available_size}")
self.logger.info(
f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}"
)
return None
start_loc = can_used_loc[0]
select_index = self.indexes[start_loc:start_loc + required_size]
select_index = self.indexes[start_loc : start_loc + required_size]
self.mem_state[select_index] = 0
self.available_size -= len(select_index)
start = start_loc.item()
......@@ -88,13 +91,13 @@ class MemoryManager:
@torch.no_grad()
def free(self, free_index):
    """Free memory by updating memory states based on given indexes.

    Args:
        free_index: tensor of slot indexes to release; its first dimension
            is added to ``self.available_size``.

    NOTE(review): the count is taken from ``free_index.shape[0]`` without
    checking whether those slots were actually allocated, so callers are
    presumably expected to pass each allocated index exactly once.
    """
    self.available_size += free_index.shape[0]
    self.mem_state[free_index] = 1
@torch.no_grad()
def free_all(self):
""" free all memory by updating memory states """
"""free all memory by updating memory states"""
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.past_key_values_length = 0
......
from .bloom import BloomInferenceForwards
from .llama import LlamaInferenceForwards
# Public names re-exported by this package.
__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"]
import math
import warnings
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
......@@ -31,17 +31,17 @@ def generate_alibi(n_head, dtype=torch.float16):
"""
def get_slopes_power_of_2(n):
    """Return ``n`` slopes forming a geometric sequence.

    The first slope is ``2 ** -(2 ** -(log2(n) - 3))`` and every following
    slope is the previous one multiplied by that same value.

    Args:
        n: number of slopes; the caller passes a power of two.

    Returns:
        list[float] of length ``n``.
    """
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start**i for i in range(n)]
def get_slopes(n):
    """Return ``n`` slopes for any positive head count.

    For a power-of-two ``n`` this is exactly ``get_slopes_power_of_2(n)``.
    Otherwise the slopes of the largest power of two below ``n`` come first,
    followed by every other slope of the next power of two, taking as many
    as are needed to reach ``n`` entries.

    Args:
        n: number of slopes to produce.

    Returns:
        list[float] of length ``n``.
    """
    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)

    closest_power_of_2 = 2 ** math.floor(math.log2(n))
    slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
    # 2 * closest_power_of_2 is itself a power of two, so this recursion ends after one step.
    slopes_double = get_slopes(2 * closest_power_of_2)
    return slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
slopes = get_slopes(n_head)
......@@ -72,7 +72,6 @@ class BloomInferenceForwards:
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
......@@ -86,8 +85,9 @@ class BloomInferenceForwards:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......@@ -122,14 +122,15 @@ class BloomInferenceForwards:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# NOTE determine if BatchInferState is passed in via arg
# if not, get the attr bound to the model
# We might want to remove setattr later
if infer_state is None:
assert hasattr(self, 'infer_state')
assert hasattr(self, "infer_state")
infer_state = self.infer_state
# Compute alibi tensor: check build_alibi_tensor documentation
......@@ -146,10 +147,11 @@ class BloomInferenceForwards:
if use_cache and seq_length != 1:
# prefill stage
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
infer_state.context_mem_index)
BatchInferState.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
......@@ -182,8 +184,11 @@ class BloomInferenceForwards:
# alibi = generate_alibi(self.num_heads).contiguous().cuda()
tp_size = dist.get_world_size()
curr_tp_rank = dist.get_rank()
alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) *
self.num_heads].cuda()
alibi = (
generate_alibi(self.num_heads * tp_size)
.contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
.cuda()
)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
......@@ -197,7 +202,6 @@ class BloomInferenceForwards:
if self.gradient_checkpointing and self.training:
# NOTE: currently our KV cache manager does not handle this condition
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
......@@ -250,32 +254,34 @@ class BloomInferenceForwards:
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents, # should always be (None, None, ..., None)
past_key_values=presents, # should always be (None, None, ..., None)
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@staticmethod
def bloom_for_causal_lm_forward(self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments):
def bloom_for_causal_lm_forward(
self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: Optional[BatchInferState] = None,
**deprecated_arguments,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
logger = logging.get_logger(__name__)
logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
......@@ -289,17 +295,19 @@ class BloomInferenceForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state)
transformer_outputs = BloomInferenceForwards.bloom_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
......@@ -314,8 +322,9 @@ class BloomInferenceForwards:
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length))
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
......@@ -353,11 +362,13 @@ class BloomInferenceForwards:
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update({
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
})
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
......@@ -416,7 +427,7 @@ class BloomInferenceForwards:
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
return outputs # hidden_states, present, attentions
@staticmethod
def bloom_attention_forward(
......@@ -431,20 +442,19 @@ class BloomInferenceForwards:
output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, H, D_HEAD = query_layer.shape
k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
if layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_length # += 1
if layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_length # += 1
if infer_state.is_context_stage:
# context process
......@@ -471,9 +481,11 @@ class BloomInferenceForwards:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(k)
cache_v.copy_(v)
else:
......@@ -486,8 +498,17 @@ class BloomInferenceForwards:
b_loc = infer_state.block_loc
b_seq_len = infer_state.seq_len
output = torch.empty_like(q)
token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc,
b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi)
token_attention_fwd(
q,
mem_manager.key_buffer[layer_id],
mem_manager.value_buffer[layer_id],
output,
b_loc,
b_start_loc,
b_seq_len,
infer_state.cache_manager.past_key_values_length,
alibi,
)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
......@@ -504,8 +525,8 @@ class BloomInferenceForwards:
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
......
from typing import List, Optional, Tuple
import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
......@@ -15,6 +14,7 @@ from colossalai.kernel.triton import (
try:
from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
......@@ -29,17 +29,17 @@ except:
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension of ``x`` at its midpoint into ``(x1, x2)`` and
    returns ``concat(-x2, x1)`` along that same dimension.

    Args:
        x: tensor whose last dimension is split in two.

    Returns:
        Tensor with the same shape as ``x``.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
......@@ -71,8 +71,7 @@ class LlamaInferenceForwards:
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
batch_size = input_ids.shape[0] # input_ids.shape[0]
batch_size = input_ids.shape[0] # input_ids.shape[0]
infer_state = self.infer_state
......@@ -103,10 +102,11 @@ class LlamaInferenceForwards:
if use_cache and seq_length != 1:
# NOTE assume prefill stage
# allocate memory block
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
infer_state.context_mem_index)
infer_state.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
......@@ -129,20 +129,20 @@ class LlamaInferenceForwards:
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if infer_state.is_context_stage:
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1)
position_ids.view(-1).shape[0], -1
)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1)
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
......@@ -153,12 +153,13 @@ class LlamaInferenceForwards:
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
......@@ -216,7 +217,6 @@ class LlamaInferenceForwards:
use_cache: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -261,7 +261,6 @@ class LlamaInferenceForwards:
use_cache: bool = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert use_cache is True, "use_cache should be set to True using this llama attention"
bsz, q_len, _ = hidden_states.size()
......@@ -277,8 +276,8 @@ class LlamaInferenceForwards:
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
......@@ -299,38 +298,62 @@ class LlamaInferenceForwards:
# first token generation
# copy key and value calculated in current step to memory manager
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
infer_state.cache_manager)
_copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc,
infer_state.seq_len, infer_state.cache_manager.past_key_values_length)
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(key_states)
cache_v.copy_(value_states)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states,
infer_state.decode_mem_index, infer_state.cache_manager)
_copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.decode_mem_index,
infer_state.cache_manager,
)
# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output,
infer_state.block_loc, infer_state.start_loc, infer_state.seq_len,
infer_state.cache_manager.past_key_values_length)
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
......@@ -341,7 +364,6 @@ class LlamaInferenceForwards:
def get_llama_vllm_rmsnorm_forward():
if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
......
from .bloom import BloomModelInferPolicy
from .llama import LlamaModelInferPolicy
# Public names re-exported by this package.
__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"]
......@@ -9,6 +9,7 @@ from ..modeling.bloom import BloomInferenceForwards
try:
from colossalai.kernel.triton import layer_norm
HAS_TRITON_NORM = True
except:
print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
......@@ -27,40 +28,40 @@ def get_triton_layernorm_forward():
class BloomModelInferPolicy(BloomForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
# NOTE set inference mode to shard config
self.shard_config._infer()
method_replacement = {
'forward': BloomInferenceForwards.bloom_for_causal_lm_forward,
'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
"forward": BloomInferenceForwards.bloom_for_causal_lm_forward,
"prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation,
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomForCausalLM)
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=BloomForCausalLM
)
method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomAttention)
method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=BloomAttention
)
if HAS_TRITON_NORM:
infer_method = get_triton_layernorm_forward()
method_replacement = {'forward': partial(infer_method)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LayerNorm)
method_replacement = {"forward": partial(infer_method)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LayerNorm
)
return policy
......@@ -10,6 +10,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw
try:
from colossalai.kernel.triton import rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
......@@ -28,7 +29,6 @@ def get_triton_rmsnorm_forward():
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
......@@ -37,20 +37,20 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)}
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaDecoderLayer)
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaAttention)
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
infer_forward = None
if HAS_TRITON_RMSNORM:
......@@ -60,9 +60,9 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
infer_forward = get_llama_vllm_rmsnorm_forward()
if infer_forward is not None:
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaRMSNorm)
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
)
return policy
......@@ -14,15 +14,17 @@ from colossalai.logging import get_dist_logger
from colossalai.utils import set_device, set_seed
def launch(config: Union[str, Path, Config, Dict],
rank: int,
world_size: int,
host: str,
port: int,
backend: str = 'nccl',
local_rank: int = None,
seed: int = 1024,
verbose: bool = True):
def launch(
config: Union[str, Path, Config, Dict],
rank: int,
world_size: int,
host: str,
port: int,
backend: str = "nccl",
local_rank: int = None,
seed: int = 1024,
verbose: bool = True,
):
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
......@@ -46,7 +48,7 @@ def launch(config: Union[str, Path, Config, Dict],
warnings.warn("`config` is deprecated and will be removed soon.")
# init default process group
init_method = f'tcp://[{host}]:{port}'
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
......@@ -58,15 +60,17 @@ def launch(config: Union[str, Path, Config, Dict],
if verbose:
logger = get_dist_logger()
logger.info(f'Distributed environment is initialized, world size: {dist.get_world_size()}', ranks=[0])
logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
def launch_from_slurm(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
def launch_from_slurm(
config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = "nccl",
seed: int = 1024,
verbose: bool = True,
):
"""A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
set by SLURM
......@@ -79,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NPROCS'])
rank = int(os.environ["SLURM_PROCID"])
world_size = int(os.environ["SLURM_NPROCS"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
)
launch(config=config,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
def launch_from_openmpi(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
launch(
config=config,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose,
)
def launch_from_openmpi(
config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = "nccl",
seed: int = 1024,
verbose: bool = True,
):
"""A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
set by OpenMPI
......@@ -114,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
)
launch(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
def launch_from_torch(config: Union[str, Path, Config, Dict],
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
launch(
config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose,
)
def launch_from_torch(
config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
):
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
......@@ -147,22 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
host = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
launch(config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose)
launch(
config=config,
local_rank=local_rank,
rank=rank,
world_size=world_size,
host=host,
port=port,
backend=backend,
seed=seed,
verbose=verbose,
)
from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper
__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
__all__ = ["OptimizerWrapper", "ModelWrapper", "AMPModelMixin"]
......@@ -26,11 +26,9 @@ class ModelWrapper(nn.Module):
class AMPModelMixin:
"""This mixin class defines the interface for AMP training.
"""
"""This mixin class defines the interface for AMP training."""
def update_master_params(self):
"""
Update the master parameters for AMP training.
"""
pass
......@@ -22,7 +22,7 @@ class OptimizerWrapper:
params = []
for group in self.param_groups:
params += group['params']
params += group["params"]
return params
@property
......@@ -82,12 +82,14 @@ class OptimizerWrapper:
"""
nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
def clip_grad_by_norm(
self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = False,
*args,
**kwargs,
) -> Tensor:
"""
Clips gradient norm of an iterable of parameters.
......@@ -113,7 +115,8 @@ class OptimizerWrapper:
loss (Tensor): The loss to be scaled.
"""
raise NotImplementedError(
"The method scale_loss is only available for optimizers with mixed precision training")
"The method scale_loss is only available for optimizers with mixed precision training"
)
def unscale_grad(self):
"""
......@@ -122,7 +125,8 @@ class OptimizerWrapper:
Note: Only available for optimizers with mixed precision training.
"""
raise NotImplementedError(
"The method unscale_grad is only available for optimizers with mixed precision training")
"The method unscale_grad is only available for optimizers with mixed precision training"
)
def unwrap(self):
"""
......
......@@ -4,6 +4,10 @@ from .multihead_attention import MultiHeadAttention
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
__all__ = [
'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
'AttnMaskType'
"LayerNorm",
"MultiHeadAttention",
"FusedScaleMaskSoftmax",
"ScaledUpperTriangMaskedSoftmax",
"ColoAttention",
"AttnMaskType",
]
......@@ -7,4 +7,4 @@
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
\ No newline at end of file
#endif
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include "cuda_util.h"
/* GPU function guard */
......
#include <chrono>
#include <ctime>
#include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
curandStatePhilox4_32_10_t *curandstate;
/**
* @brief element-wise activation function on device, like Relu, Gelu
*
* @tparam enum class ActivationType, kRelu, kGelu
* @tparam input type
* @param any shape of float and __half2
* @return same shape and type with input
*/
template <ActivationType, typename T>
__forceinline__ __device__ T activation_kernel(T x);
template <>
__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
  // GELU, tanh approximation:
  //   x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  const float cubic_term = 0.044715f * x * x * x;
  const float tanh_arg = 0.7978845608028654f * (x + cubic_term);
  const float gate = 0.5f * (1.0f + tanhf(tanh_arg));
  return x * gate;
}
template <>
__device__ __half2
activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
  // The cube is formed in half precision; the tanh gate is evaluated in
  // float for both lanes and rounded back to half before the final multiply.
  const __half2 cubed = __hmul2(val, __hmul2(val, val));
  const float2 x3 = __half22float2(cubed);
  const float2 x1 = __half22float2(val);

  float2 gate;
  gate.x =
      0.5f * (1.0f + tanhf((0.7978845608028654f * (x1.x + 0.044715f * x3.x))));
  gate.y =
      0.5f * (1.0f + tanhf((0.7978845608028654f * (x1.y + 0.044715f * x3.y))));

  return __hmul2(val, __float22half2_rn(gate));
}
template <>
__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
  // ReLU: max(x, 0).
  const float zero = 0;
  return fmaxf(x, zero);
}
template <>
__device__ __half2
activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
  // ReLU applied independently to both half lanes; the clamp itself is
  // evaluated in float and the pair is rounded back to __half2.
  const float lo = fmaxf(0.f, __half2float(x.x));
  const float hi = fmaxf(0.f, __half2float(x.y));
  return __floats2half2_rn(lo, hi);
}
/**
* @brief element-wise activation backward function on device
*
* @tparam enum class ActivationType
* @tparam input type
* @param any shape of float and __half2
* @return same shape of input
*/
template <ActivationType, typename T>
__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
template <>
__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
                                                                     float x) {
  // Derivative of the tanh-approximated GELU, multiplied by the incoming
  // gradient. With t = tanh(k * (x + c * x^3)):
  //   d/dx = 0.5 * (1 + t) + 0.5 * k * x * (1 - t^2) * (1 + 3 * c * x^2)
  const float k = 0.79788456080286535587989211986876f;
  const float c = 0.044715;

  const float cx2 = x * x * c;
  const float t = tanhf(k * (x + x * cx2));

  const float term_gate = 0.5f * (1.0f + t);
  const float term_lin = x * 0.5f * k * (1 - t * t);
  const float term_cubic = term_lin * 3 * cx2;

  return grad * (term_gate + term_lin + term_cubic);
}
template <>
__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
    __half grad, __half x_half) {
  // Same derivative as the float specialization: the derivative is computed
  // in float from the widened input, rounded to half, then multiplied by the
  // half-precision gradient.
  const float k = 0.79788456080286535587989211986876f;
  const float c = 0.044715;

  const float x = __half2float(x_half);
  const float cx2 = x * x * c;
  const float t = tanhf(k * (x + x * cx2));

  const float term_gate = 0.5f * (1.0f + t);
  const float term_lin = x * 0.5f * k * (1 - t * t);
  const float term_cubic = term_lin * 3 * cx2;

  return grad * __float2half(term_gate + term_lin + term_cubic);
}
template <>
__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
                                                                     float x) {
  // ReLU backward: the gradient passes through only where the input was
  // strictly positive.
  if (x > 0.f) {
    return grad;
  }
  return 0.f;
}
template <>
__device__ __half
activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
  // ReLU backward in half precision: pass the gradient where x > 0,
  // otherwise return zero.
  const __half zero = __float2half(0.f);
  if (x > zero) {
    return grad;
  }
  return zero;
}
template <>
__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
    __half2 grad2, __half2 x_half2) {
  // ReLU backward on a packed pair: each lane keeps its gradient only when
  // the matching input lane is strictly positive.
  const __half zero = __float2half(0.f);
  const __half lo = x_half2.x > zero ? grad2.x : zero;
  const __half hi = x_half2.y > zero ? grad2.y : zero;
  return __floats2half2_rn(lo, hi);
}
/**
* @brief init curand states in global memory
*
* @thread grid_dim * block_dim threads, to support any number of states
* @param state persistent curand states
* @param seed seed to init states
* @return void
*/
__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
                                   int seed) {
  // One state per thread: all threads share the seed, the global thread
  // index is used as the sequence number, and the offset is zero.
  // There is no bounds check, so the launch must not cover more threads
  // than `state` has entries.
  const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;
  curand_init(seed, global_tid, 0, &state[global_tid]);
}
/**
 * Allocate the global `curandstate` buffer and initialise one Philox state
 * per element on `stream`.
 *
 * Fixes over the previous version:
 *  - the grid size is rounded up, so counts below 512 no longer produce a
 *    zero-block (invalid) launch and counts that are not a multiple of 512
 *    no longer leave the tail uninitialised;
 *  - the buffer is padded to a whole number of blocks because
 *    curand_init_kernel has no bounds check;
 *  - a buffer left over from an earlier call is released instead of leaked;
 *  - a failed allocation no longer launches the kernel on a bad pointer.
 *
 * @param total_count number of states requested; values <= 0 are a no-op
 * @param dim         unused, kept for interface compatibility
 * @param stream      stream the initialisation kernel is launched on
 */
void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
  (void)dim;
  if (total_count <= 0) return;

  const int threads_per_block = 512;
  // Ceil-divide without risking overflow of total_count + 511.
  const int grid_dim = (total_count - 1) / threads_per_block + 1;
  const size_t padded_count =
      static_cast<size_t>(grid_dim) * threads_per_block;

  // cudaFree(nullptr) is a no-op, so this is safe on the first call.
  cudaFree(curandstate);
  curandstate = nullptr;
  if (cudaMalloc(&curandstate,
                 padded_count * sizeof(curandStatePhilox4_32_10_t)) !=
      cudaSuccess) {
    curandstate = nullptr;
    return;
  }

  // The kernel takes an int seed; make the narrowing of the microsecond
  // timestamp explicit.
  const int seed =
      static_cast<int>(std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count());
  curand_init_kernel<<<grid_dim, threads_per_block, 0, stream>>>(curandstate,
                                                                 seed);
}
/**
* @brief element-wise dropout, store dropped position in mask, it's not
* in-place
*
* @thread
* gridDim.x = total_count / 4096 + 1 (float) or total_count / 8192 + 1 (__half)
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param out any size of float and __half
* @param in same with out
* @param mask uint8 type, same size with out
* @param seed seed to curand
* @return void
*/
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  float *__restrict__ out,
                                  const float *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  // Each thread processes one float4 (4 consecutive elements) and the
  // matching 4 mask bytes, written as one 32-bit word.
  // NOTE(review): the vector accesses assume total_count is a multiple of 4
  // and that out/in/mask are suitably aligned -- confirm against callers.
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 4 >= total_count) return;

  // Kept elements are scaled by 1 / (1 - ratio).
  const float scale = 1.f / (1.f - ratio);

  // Fresh per-thread Philox state: shared seed, thread index as sequence.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_idx, 0, &rng);
  const float4 u = curand_uniform4(&rng);

  // keep[k] == 1 means the element survives dropout.
  uint8_t keep[4];
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);
  reinterpret_cast<uint32_t *>(mask)[vec_idx] =
      *reinterpret_cast<uint32_t *>(keep);

  const float4 src = reinterpret_cast<const float4 *>(in)[vec_idx];
  float4 dst;
  dst.x = src.x * scale * keep[0];
  dst.y = src.y * scale * keep[1];
  dst.z = src.z * scale * keep[2];
  dst.w = src.w * scale * keep[3];
  reinterpret_cast<float4 *>(out)[vec_idx] = dst;
}
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  __half *__restrict__ out,
                                  const __half *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  // Each thread processes 8 halves (loaded as one float4) and the matching
  // 8 mask bytes, written as one 64-bit word.
  // NOTE(review): the vector accesses assume total_count is a multiple of 8
  // and that out/in/mask are suitably aligned -- confirm against callers.
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 8 >= total_count) return;

  // Kept elements are scaled by 1 / (1 - ratio).
  const float scale = 1.f / (1.f - ratio);

  // Fresh per-thread Philox state: shared seed, thread index as sequence.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_idx, 0, &rng);

  // Two draws of four uniforms give the eight keep flags.
  uint8_t keep[8];
  float4 u = curand_uniform4(&rng);
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);
  u = curand_uniform4(&rng);
  keep[4] = static_cast<uint8_t>(u.x > ratio);
  keep[5] = static_cast<uint8_t>(u.y > ratio);
  keep[6] = static_cast<uint8_t>(u.z > ratio);
  keep[7] = static_cast<uint8_t>(u.w > ratio);
  reinterpret_cast<uint64_t *>(mask)[vec_idx] =
      *reinterpret_cast<uint64_t *>(keep);

  // View the 16-byte payload as four __half2 pairs.
  float4 src = reinterpret_cast<const float4 *>(in)[vec_idx];
  float4 dst;
  __half2 *src_h2 = reinterpret_cast<__half2 *>(&src);
  __half2 *dst_h2 = reinterpret_cast<__half2 *>(&dst);

  const __half2 factor_0 = __floats2half2_rn(scale * keep[0], scale * keep[1]);
  const __half2 factor_1 = __floats2half2_rn(scale * keep[2], scale * keep[3]);
  const __half2 factor_2 = __floats2half2_rn(scale * keep[4], scale * keep[5]);
  const __half2 factor_3 = __floats2half2_rn(scale * keep[6], scale * keep[7]);

  dst_h2[0] = __hmul2(src_h2[0], factor_0);
  dst_h2[1] = __hmul2(src_h2[1], factor_1);
  dst_h2[2] = __hmul2(src_h2[2], factor_2);
  dst_h2[3] = __hmul2(src_h2[3], factor_3);

  reinterpret_cast<float4 *>(out)[vec_idx] = dst;
}
/**
* @brief element-wise dropout backward with dropout mask, it's
* not in-place
*
* @thread
* gridDim.x = total_count / 4096 + 1 (float) or total_count / 8192 + 1 (__half)
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param in any size of float and __half
* @param mask uint8 type, same size with in
* @return void
*/
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      float *out, const float *in,
                                      const uint8_t *__restrict__ mask) {
  // Dropout backward: out = in * scale * mask, four floats per thread.
  // NOTE(review): assumes total_count is a multiple of 4 and suitably
  // aligned buffers -- confirm against callers.
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 4 >= total_count) return;

  const float scale = 1.f / (1.f - ratio);

  // Fetch the four mask bytes with a single 32-bit load.
  uint8_t keep[4];
  *reinterpret_cast<uint32_t *>(keep) =
      reinterpret_cast<const uint32_t *>(mask)[vec_idx];

  const float4 src = reinterpret_cast<const float4 *>(in)[vec_idx];
  float4 dst;
  dst.x = src.x * scale * static_cast<float>(keep[0]);
  dst.y = src.y * scale * static_cast<float>(keep[1]);
  dst.z = src.z * scale * static_cast<float>(keep[2]);
  dst.w = src.w * scale * static_cast<float>(keep[3]);
  reinterpret_cast<float4 *>(out)[vec_idx] = dst;
}
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      __half *out, const __half *in,
                                      const uint8_t *__restrict__ mask) {
  // Dropout backward in half precision: out = in * scale * mask, eight
  // halves per thread (one float4 payload, one 64-bit mask word).
  // NOTE(review): assumes total_count is a multiple of 8 and suitably
  // aligned buffers -- confirm against callers.
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 8 >= total_count) return;

  // The scale factor is rounded to half once and reused for all lanes.
  const __half scale = 1.f / (1.f - ratio);

  uint8_t keep[8];
  *reinterpret_cast<uint64_t *>(keep) =
      reinterpret_cast<const uint64_t *>(mask)[vec_idx];

  float4 src = reinterpret_cast<const float4 *>(in)[vec_idx];
  float4 dst;
  __half2 *src_h2 = reinterpret_cast<__half2 *>(&src);
  __half2 *dst_h2 = reinterpret_cast<__half2 *>(&dst);

  const __half2 factor_0 = __halves2half2(scale * __float2half(keep[0]),
                                          scale * __float2half(keep[1]));
  const __half2 factor_1 = __halves2half2(scale * __float2half(keep[2]),
                                          scale * __float2half(keep[3]));
  const __half2 factor_2 = __halves2half2(scale * __float2half(keep[4]),
                                          scale * __float2half(keep[5]));
  const __half2 factor_3 = __halves2half2(scale * __float2half(keep[6]),
                                          scale * __float2half(keep[7]));

  dst_h2[0] = __hmul2(src_h2[0], factor_0);
  dst_h2[1] = __hmul2(src_h2[1], factor_1);
  dst_h2[2] = __hmul2(src_h2[2], factor_2);
  dst_h2[3] = __hmul2(src_h2[3], factor_3);

  reinterpret_cast<float4 *>(out)[vec_idx] = dst;
}
template <>
void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
                              int total_count, float ratio, cudaStream_t stream,
                              bool backward) {
  // 1024 threads x 4 floats per thread = 4096 elements per block; the extra
  // block covers any remainder (surplus threads exit on the kernel's guard).
  const int num_blocks = (total_count >> 12) + 1;

  if (backward) {
    // Backward reuses the stored mask, so no seed is needed.
    ls_dropout_bwd_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio,
                                                           out, vals, mask);
    return;
  }

  // Forward pass: seed the RNG from the wall clock in microseconds,
  // narrowed to the kernel's int parameter.
  const int seed =
      static_cast<int>(std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count());
  ls_dropout_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio, out,
                                                     vals, mask, seed);
}
template <>
void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
                               int total_count, float ratio,
                               cudaStream_t stream, bool backward) {
  // 1024 threads x 8 halves per thread = 8192 elements per block; the extra
  // block covers any remainder (surplus threads exit on the kernel's guard).
  const int num_blocks = (total_count >> 13) + 1;

  if (backward) {
    // Backward reuses the stored mask, so no seed is needed.
    ls_dropout_bwd_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio,
                                                           out, vals, mask);
    return;
  }

  // Forward pass: seed the RNG from the wall clock in microseconds,
  // narrowed to the kernel's int parameter.
  const int seed =
      static_cast<int>(std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count());
  ls_dropout_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio, out,
                                                     vals, mask, seed);
}
/**
* @brief fused bias, dropout, and residual at the end of Attention and FFN,
* store dropped position in mask, it's not in-place
*
* @thread
* gridDim.x = total_count / 4096 + 1 (float) or total_count / 8192 + 1 (__half)
* blockDim.x = 1024
*
* @param total_count total elements
* @param ratio drop ratio
* @param out [batch_size, seq_len, hidden_size], float and __half
* @param in [batch_size, seq_len, hidden_size], float and __half
* @param mask [batch_size, seq_len, hidden_size], uint8 type
* @param bias [hidden_size], ffn bias
* @param residual [batch_size, seq_len, hidden_size], float and __half
* @param seed seed to curand
* @param hidden_size hidden size
* @return void
*/
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const float *__restrict__ residual,
    const int seed, const int hidden_size) {
  // out = dropout(in + bias) + residual, four floats per thread.
  // NOTE(review): assumes total_count and hidden_size are multiples of 4
  // and suitably aligned buffers -- confirm against callers.
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 4 >= total_count) return;

  const float scale = 1.f / (1.f - ratio);

  // Fresh per-thread Philox state: shared seed, thread index as sequence.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_idx, 0, &rng);
  const float4 u = curand_uniform4(&rng);

  uint8_t keep[4];
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);
  reinterpret_cast<uint32_t *>(mask)[vec_idx] =
      *reinterpret_cast<uint32_t *>(keep);

  // The bias vector repeats every hidden_size elements, i.e. every
  // hidden_size / 4 float4 groups.
  const int bias_vec_idx = vec_idx % (hidden_size >> 2);

  const float4 x = reinterpret_cast<const float4 *>(in)[vec_idx];
  const float4 b = __ldg(&reinterpret_cast<const float4 *>(bias)[bias_vec_idx]);
  const float4 r = reinterpret_cast<const float4 *>(residual)[vec_idx];

  float4 y;
  y.x = (x.x + b.x) * scale * keep[0] + r.x;
  y.y = (x.y + b.y) * scale * keep[1] + r.y;
  y.z = (x.z + b.z) * scale * keep[2] + r.z;
  y.w = (x.w + b.w) * scale * keep[3] + r.w;
  reinterpret_cast<float4 *>(out)[vec_idx] = y;
}
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const __half *__restrict__ residual,
    const int seed, const int hidden_size) {
  // out = dropout(in + bias) + residual, eight halves per thread.
  // NOTE(review): assumes total_count and hidden_size are multiples of 8
  // and suitably aligned buffers -- confirm against callers.
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 8 >= total_count) return;

  // The scale factor is computed in double and rounded to half once.
  const __half scale = 1. / (1. - ratio);

  // Fresh per-thread Philox state: shared seed, thread index as sequence.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_idx, 0, &rng);

  // Two draws of four uniforms give the eight keep flags.
  uint8_t keep[8];
  float4 u = curand_uniform4(&rng);
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);
  u = curand_uniform4(&rng);
  keep[4] = static_cast<uint8_t>(u.x > ratio);
  keep[5] = static_cast<uint8_t>(u.y > ratio);
  keep[6] = static_cast<uint8_t>(u.z > ratio);
  keep[7] = static_cast<uint8_t>(u.w > ratio);
  reinterpret_cast<uint64_t *>(mask)[vec_idx] =
      *reinterpret_cast<uint64_t *>(keep);

  // The bias vector repeats every hidden_size elements, i.e. every
  // hidden_size / 8 sixteen-byte groups.
  const int bias_vec_idx = vec_idx % (hidden_size >> 3);

  float4 x_raw = reinterpret_cast<const float4 *>(in)[vec_idx];
  const float4 b_raw =
      __ldg(&reinterpret_cast<const float4 *>(bias)[bias_vec_idx]);
  const float4 r_raw = reinterpret_cast<const float4 *>(residual)[vec_idx];
  float4 y_raw;

  // View each 16-byte payload as four __half2 pairs.
  __half2 *x_h2 = reinterpret_cast<__half2 *>(&x_raw);
  __half2 *y_h2 = reinterpret_cast<__half2 *>(&y_raw);
  const __half2 *b_h2 = reinterpret_cast<const __half2 *>(&b_raw);
  const __half2 *r_h2 = reinterpret_cast<const __half2 *>(&r_raw);

  const __half2 factor_0 = __halves2half2(scale * __float2half(keep[0]),
                                          scale * __float2half(keep[1]));
  const __half2 factor_1 = __halves2half2(scale * __float2half(keep[2]),
                                          scale * __float2half(keep[3]));
  const __half2 factor_2 = __halves2half2(scale * __float2half(keep[4]),
                                          scale * __float2half(keep[5]));
  const __half2 factor_3 = __halves2half2(scale * __float2half(keep[6]),
                                          scale * __float2half(keep[7]));

  // (x + bias) * factor + residual, fused per pair.
  y_h2[0] = __hfma2(__hadd2(x_h2[0], b_h2[0]), factor_0, r_h2[0]);
  y_h2[1] = __hfma2(__hadd2(x_h2[1], b_h2[1]), factor_1, r_h2[1]);
  y_h2[2] = __hfma2(__hadd2(x_h2[2], b_h2[2]), factor_2, r_h2[2]);
  y_h2[3] = __hfma2(__hadd2(x_h2[3], b_h2[3]), factor_3, r_h2[3]);

  reinterpret_cast<float4 *>(out)[vec_idx] = y_raw;
}
template <>
void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
                                       uint8_t *mask, const float *bias,
                                       const float *residual, int total_count,
                                       int dim, float ratio,
                                       cudaStream_t stream) {
  // 1024 threads x 4 floats per thread = 4096 elements per block; the extra
  // block covers any remainder (surplus threads exit on the kernel's guard).
  const int num_blocks = (total_count >> 12) + 1;

  // Seed the RNG from the wall clock in microseconds, narrowed to the
  // kernel's int parameter.
  const int seed =
      static_cast<int>(std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count());

  ls_dropout_res_bias_kernel<<<num_blocks, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual, seed, dim);
}
template <>
void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
                                        uint8_t *mask, const __half *bias,
                                        const __half *residual, int total_count,
                                        int dim, float ratio,
                                        cudaStream_t stream) {
  // 1024 threads x 8 halves per thread = 8192 elements per block; the extra
  // block covers any remainder (surplus threads exit on the kernel's guard).
  const int num_blocks = (total_count >> 13) + 1;

  // Seed the RNG from the wall clock in microseconds, narrowed to the
  // kernel's int parameter.
  const int seed =
      static_cast<int>(std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count());

  ls_dropout_res_bias_kernel<<<num_blocks, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual, seed, dim);
}
/**
* @brief fused bias and dropout backward at the end of Attention and FFN
*
* @thread
* gridDim.x = hidden_size / 8
* blockDim.x = 8
* blockDim.y = 1024 / 8 = 128
*
* @param row_size batch_size * seq_len
* @param ratio dropout ratio
* @param in_grad [batch_size, seq_len, hidden_size], input grad
* @param bias_grad [hidden_size], bias grad
* @param out_grad [batch_size, seq_len, hidden_size], output grad
* @param mask [batch_size, seq_len, hidden_size], dropout mask
* @param hidden_size
* @return void
*/
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, float *__restrict__ in_grad,
    float *__restrict__ bias_grad, const float *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  // Computes in_grad = out_grad * scale * mask element-wise and
  // bias_grad[col] = sum over rows of in_grad[row][col].
  // Expected launch shape: blockDim = (8, 128); each block owns 8 columns
  // and the 128 threadIdx.y lanes stride over the rows.
  //
  // Fix: the launcher rounds the grid up with (dim - 1) / 8 + 1, so the
  // last block can own columns >= hidden_size. Those threads now skip all
  // global memory accesses instead of going out of bounds; they still take
  // part in every __syncthreads() and shuffle.
  // presumably flat_2dim(a, b, n) == a * n + b -- confirm in its header.
  const float scale = 1.f / (1.f - ratio);
  // every block generates 8 bias results; the 129-wide row adds one
  // padding column to the 128 used entries
  __shared__ float tile[8][129];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  const int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  const bool col_in_range = col_idx < hidden_size;
  const int stride = hidden_size * 128;
  float local_sum = 0;

  if (col_in_range) {
    int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
    for (int r = threadIdx.y; r < row_size; r += 128) {
      float val = out_grad[idx];
      val *= scale * static_cast<float>(mask[idx]);
      local_sum += val;
      in_grad[idx] = val;
      idx += stride;
    }
  }

  // Out-of-range columns contribute a zero partial sum.
  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  // Re-map threads so that one full warp handles each column: x selects the
  // column, y is the position inside that column's 128 partial sums.
  float sum = 0;
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  __syncthreads();

  // Warp reduction; lane 0 of each reducing warp ends up with the total.
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  // One thread per column writes the result (previously every threadIdx.y
  // row wrote the same value), and only for columns that exist.
  if (threadIdx.y == 0 && threadIdx.x < 8 && col_in_range) {
    bias_grad[col_idx] = tile[0][threadIdx.x];
  }
}
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, __half *__restrict__ in_grad,
    __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  // Half-precision variant operating on __half2 pairs: computes
  // in_grad = out_grad * scale * mask and the per-column sum of in_grad.
  // Here hidden_size counts __half2 columns (the launcher visible in this
  // file passes dim / 2), while mask is still indexed per scalar element.
  // Expected launch shape: blockDim = (8, 128).
  //
  // Fix: the launcher rounds the grid up with (dim - 1) / 8 + 1, so the
  // last block can own columns >= hidden_size. Those threads now skip all
  // global memory accesses instead of going out of bounds; they still take
  // part in every __syncthreads() and shuffle.
  // presumably flat_2dim(a, b, n) == a * n + b -- confirm in its header.
  const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
  // 8 columns per block; the 129-wide row adds one padding column
  __shared__ __half2 tile[8][129];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
  const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
  __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);

  const int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  const bool col_in_range = col_idx < hidden_size;
  const int stride = hidden_size * 128;
  __half2 local_sum = __float2half2_rn(0.f);

  if (col_in_range) {
    int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
    for (int r = threadIdx.y; r < row_size; r += 128) {
      __half2 val = out_grad2[idx];
      // Two mask bytes per __half2 element.
      __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
      val *= scale * m2;
      local_sum += val;
      in_grad2[idx] = val;
      idx += stride;
    }
  }

  // Out-of-range columns contribute a zero partial sum.
  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  // Re-map threads so that one full warp handles each column: x selects the
  // column, y is the position inside that column's 128 partial sums.
  __half2 sum = __float2half2_rn(0.f);
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  __syncthreads();

  // Warp reduction; lane 0 of each reducing warp ends up with the total.
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  // One thread per column writes the result (previously every threadIdx.y
  // row wrote the same value), and only for columns that exist.
  if (threadIdx.y == 0 && threadIdx.x < 8 && col_in_range) {
    bias_grad2[col_idx] = tile[0][threadIdx.x];
  }
}
template <typename T>
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
                                const uint8_t *mask, int row_size, int dim,
                                float ratio, cudaStream_t stream) {
  // Each block is 8 columns x 128 row lanes; the grid is ceil(dim / 8)
  // blocks so that every column is covered.
  const dim3 threads(8, 128);
  const dim3 blocks((dim - 1) / 8 + 1);
  ls_dropout_bias_bwd_kernel<<<blocks, threads, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}
/**
 * @brief __half specialization of the bias/dropout backward launcher
 *
 * The kernel is handed dim / 2 as its hidden size, and the grid is sized
 * from that halved value.
 */
template <>
void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
                                const __half *out_grad, const uint8_t *mask,
                                int row_size, int dim, float ratio,
                                cudaStream_t stream) {
  const int half2_dim = dim >> 1;
  const int column_groups = (half2_dim - 1) / 8 + 1;
  const dim3 grid(column_groups);
  const dim3 block(8, 128);
  ls_dropout_bias_bwd_kernel<<<grid, block, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, half2_dim);
}
// Explicit instantiation of the generic launcher for float; the __half case
// is covered by the explicit specialization above.
template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
const float *out_grad,
const uint8_t *mask, int row_size,
int dim, float ratio,
cudaStream_t stream);
/**
 * @brief fused bias, activation, and dropout at the end of first ffn (float)
 *
 * Each thread handles 4 consecutive floats: one float4 load, one float4
 * store and one packed 32-bit mask store.
 *
 * @thread (as configured by the float launchers in this file)
 * gridDim.x = total_count / 1024 + 1
 * blockDim.x = 256
 *
 * @tparam act_type activation function, like kRelu, kGelu
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out output, same element count as in
 * @param in input
 * @param mask uint8 keep mask written by this kernel (1 = kept, 0 = dropped)
 * @param bias [hidden_size], ffn bias
 * @param seed seed to curand
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const int seed, const int hidden_size) {
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 4 >= total_count) return;

  // Per-thread Philox generator: common seed, subsequence = thread index.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_id, 0, &rng);
  const float4 u = curand_uniform4(&rng);

  uint8_t keep[4];
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);

  // Store the four mask bytes with a single 32-bit write.
  const uint32_t *packed_keep = reinterpret_cast<uint32_t *>(keep);
  reinterpret_cast<uint32_t *>(mask)[vec_id] = packed_keep[0];

  // The bias repeats every hidden_size elements, i.e. hidden_size / 4 vectors.
  const int bias_vec = vec_id % (hidden_size >> 2);
  const float4 x = reinterpret_cast<const float4 *>(in)[vec_id];
  const float4 bv = __ldg(reinterpret_cast<const float4 *>(bias) + bias_vec);

  const float scale = 1.f / (1.f - ratio);
  float4 y;
  y.x = activation_kernel<act_type, float>(x.x + bv.x) * scale * keep[0];
  y.y = activation_kernel<act_type, float>(x.y + bv.y) * scale * keep[1];
  y.z = activation_kernel<act_type, float>(x.z + bv.z) * scale * keep[2];
  y.w = activation_kernel<act_type, float>(x.w + bv.w) * scale * keep[3];
  reinterpret_cast<float4 *>(out)[vec_id] = y;
}
/**
 * @brief fused bias, activation, and dropout at the end of first ffn (__half)
 *
 * Each thread handles 8 consecutive __half values: a 16-byte load/store
 * (typed as float4) viewed as four __half2 lanes, plus one packed 64-bit
 * mask store.
 */
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const int seed, const int hidden_size) {
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 8 >= total_count) return;

  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_id, 0, &rng);

  // Two draws of four uniforms give the eight keep flags.
  uint8_t keep[8];
  float4 u = curand_uniform4(&rng);
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);
  u = curand_uniform4(&rng);
  keep[4] = static_cast<uint8_t>(u.x > ratio);
  keep[5] = static_cast<uint8_t>(u.y > ratio);
  keep[6] = static_cast<uint8_t>(u.z > ratio);
  keep[7] = static_cast<uint8_t>(u.w > ratio);

  const uint64_t *packed_keep = reinterpret_cast<uint64_t *>(keep);
  reinterpret_cast<uint64_t *>(mask)[vec_id] = *packed_keep;

  // The bias repeats every hidden_size elements, i.e. hidden_size / 8 vectors.
  const int bias_vec = vec_id % (hidden_size >> 3);
  float4 in_vec = reinterpret_cast<const float4 *>(in)[vec_id];
  const float4 bias_vec4 =
      __ldg(reinterpret_cast<const float4 *>(bias) + bias_vec);
  float4 out_vec;

  __half2 *in_h2 = reinterpret_cast<__half2 *>(&in_vec);
  __half2 *out_h2 = reinterpret_cast<__half2 *>(&out_vec);
  const __half2 *bias_h2 = reinterpret_cast<const __half2 *>(&bias_vec4);

  const float scale = 1.f / (1.f - ratio);
#pragma unroll
  for (int k = 0; k < 4; ++k) {
    const __half2 scale_keep =
        __floats2half2_rn(scale * keep[2 * k], scale * keep[2 * k + 1]);
    out_h2[k] = __hmul2(
        activation_kernel<act_type, __half2>(__hadd2(in_h2[k], bias_h2[k])),
        scale_keep);
  }
  reinterpret_cast<float4 *>(out)[vec_id] = out_vec;
}
/**
 * @brief launcher: bias + GELU + dropout forward, float
 *
 * A 256-thread block covers 1024 floats (4 per thread); one extra block is
 * always added for the remainder. The seed is the current wall-clock time in
 * microseconds, narrowed to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto seed =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch)
          .count();
  const int num_blocks = (total_count >> 10) + 1;
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, seed, dim);
}
/**
 * @brief launcher: bias + GELU + dropout forward, __half
 *
 * A 256-thread block covers 2048 halves (8 per thread); one extra block is
 * always added for the remainder. The seed is the current wall-clock time in
 * microseconds, narrowed to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto seed =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch)
          .count();
  const int num_blocks = (total_count >> 11) + 1;
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, seed, dim);
}
/**
 * @brief launcher: bias + ReLU + dropout forward, float
 *
 * A 256-thread block covers 1024 floats (4 per thread); one extra block is
 * always added for the remainder. The seed is the current wall-clock time in
 * microseconds, narrowed to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto seed =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch)
          .count();
  const int num_blocks = (total_count >> 10) + 1;
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, seed, dim);
}
/**
 * @brief launcher: bias + ReLU + dropout forward, __half
 *
 * A 256-thread block covers 2048 halves (8 per thread); one extra block is
 * always added for the remainder. The seed is the current wall-clock time in
 * microseconds, narrowed to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto seed =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch)
          .count();
  const int num_blocks = (total_count >> 11) + 1;
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, seed, dim);
}
/**
 * @brief fused bias, activation, and dropout backward
 *
 * Every thread walks one column, accumulating the per-column gradient in
 * float; a WARP_SIZE x WARP_SIZE shared tile plus a warp shuffle reduce the
 * partial sums to one bias gradient per column.
 *
 * @thread (as configured by launch_ls_dropout_act_bias_bwd)
 * gridDim.x = ceil(hidden_size / WARP_SIZE)
 * blockDim.x = WARP_SIZE
 * blockDim.y = WARP_SIZE
 *
 * @tparam act_type activation whose backward is applied (kRelu, kGelu)
 * @tparam T element type (float or __half)
 * @param row_size batch_size * seq_len
 * @param ratio dropout ratio
 * @param in_grad [batch_size, seq_len, hidden_size], input grad
 * @param bias_grad [hidden_size], bias grad
 * @param input [batch_size, seq_len, hidden_size], forward input (pre-bias)
 * @param bias [hidden_size], forward bias
 * @param out_grad [batch_size, seq_len, hidden_size], output grad
 * @param mask [batch_size, seq_len, hidden_size], dropout mask
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type, typename T>
__global__ void ls_dropout_act_bias_bwd_kernel(
    const int row_size, const float ratio, T *in_grad,
    T *__restrict__ bias_grad, const T *__restrict__ input,
    const T *__restrict__ bias, const T *out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  // one extra column of padding (presumably against bank conflicts on the
  // transposed read below)
  __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  // NOTE(review): flat_2dim(a, b, n) is presumably a * n + b -- confirm.
  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
  int stride = hidden_size * WARP_SIZE;
  float local_sum = 0;
  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  // columns past hidden_size (last block only) contribute nothing
  if (col_idx < hidden_size) {
    for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
      float val = out_grad[idx];
      float in_val = input[idx];
      float bias_val = bias[idx % hidden_size];
      val = activation_bwd_kernel<act_type, float>(
          val * scale * static_cast<float>(mask[idx]), in_val + bias_val);
      local_sum += val;
      in_grad[idx] = val;
      idx += stride;
    }
  }

  // Transpose through shared memory so that each warp (fixed threadIdx.y)
  // holds all partial sums of a single column.
  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();
  float sum = tile[threadIdx.y][threadIdx.x];
  // all reads must finish before tile[0][*] is overwritten below
  __syncthreads();

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
  __syncthreads();

  if (threadIdx.y == 0) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
    // The grid is rounded up, so the last block can address columns past
    // hidden_size; storing those would write outside bias_grad.
    if (pos < hidden_size) bias_grad[pos] = tile[0][threadIdx.x];
  }
}
// @brief fused bias, activation, and dropout backward
// It is deprecated for precision reason. Keep it for future optimization.
//
// template <ActivationType act_type>
// __global__ void ls_dropout_act_bias_bwd_kernel(
// const int row_size, const float ratio, __half * in_grad,
// __half *__restrict__ bias_grad, const __half *__restrict__ input, const
// __half *__restrict__ bias, const __half * out_grad, const uint8_t
// *__restrict__ mask, const int hidden_size) {
// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];
// cg::thread_block b = cg::this_thread_block();
// cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
// const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
// const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
// const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);
// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// int stride = hidden_size * WARP_SIZE;
// __half2 local_sum = __float2half2_rn(0.f);
// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
// if (col_idx < hidden_size) {
// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
// __half2 val = out_grad2[idx];
// __half2 in2 = input2[idx];
// __half2 b2 = bias2[idx % hidden_size ];
// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
// val = activation_bwd_kernel<ActivationType::kRelu, __half2>(val * scale
// *
// m2,
// in2+b2);
// local_sum += val;
// in_grad2[idx] = val;
// idx += stride;
// }
// }
// tile[threadIdx.x][threadIdx.y] = local_sum;
// __syncthreads();
// __half2 sum = tile[threadIdx.y][threadIdx.x];
// __syncthreads();
// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
// __syncthreads();
// if (threadIdx.y == 0) {
// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// bias_grad2[pos] = tile[0][threadIdx.x];
// }
// }
/**
 * @brief host launcher for ls_dropout_act_bias_bwd_kernel
 *
 * Uses WARP_SIZE x WARP_SIZE thread blocks, one per WARP_SIZE columns
 * (rounded up).
 */
template <ActivationType act_type, typename T>
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
                                    const T *bias, const T *out_grad,
                                    const uint8_t *mask, int row_size, int dim,
                                    float ratio, cudaStream_t stream) {
  const int column_groups = (dim - 1) / WARP_SIZE + 1;
  const dim3 grid(column_groups);
  const dim3 block(WARP_SIZE, WARP_SIZE);
  ls_dropout_act_bias_bwd_kernel<act_type><<<grid, block, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim);
}
// template <>
// void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
// __half *in_grad, __half *bias_grad,const __half *input, const __half
// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int
// dim, float ratio, cudaStream_t stream) {
// dim >>= 1;
// dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
// dim3 block_dim(WARP_SIZE, WARP_SIZE);
// ls_dropout_act_bias_bwd_kernel<ActivationType::kRelu>
// <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
// bias_grad,
// input, bias,out_grad, mask, dim);
// }
// Explicit instantiations of the backward launcher for every supported
// (activation, element type) pair: kRelu/kGelu x float/__half.
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
float *in_grad, float *bias_grad, const float *input, const float *bias,
const float *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
__half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
const __half *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
float *in_grad, float *bias_grad, const float *input, const float *bias,
const float *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
__half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
const __half *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
#include <chrono>
#include <ctime>
#include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
// Array of Philox curand states; allocated with cudaMalloc and initialized
// by launch_curand_init.
curandStatePhilox4_32_10_t *curandstate;
/**
 * @brief element-wise activation function on device, like Relu, Gelu
 *
 * Declaration of the primary template only; this file provides explicit
 * specializations for (kGelu | kRelu) x (float | __half2).
 *
 * @tparam ActivationType activation selector, e.g. kRelu, kGelu
 * @tparam T value type, float or __half2
 * @param x input value (two packed halves for __half2)
 * @return activated value, same type as the input
 */
template <ActivationType, typename T>
__forceinline__ __device__ T activation_kernel(T x);
/**
 * @brief GELU (tanh approximation) for float
 *
 * Computes x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
 */
template <>
__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
  const float tanh_arg = 0.7978845608028654f * (x + 0.044715f * x * x * x);
  return x * (0.5f * (1.0f + tanhf(tanh_arg)));
}
/**
 * @brief GELU (tanh approximation) for two packed halves
 *
 * The cube is formed in half precision, the tanh term is evaluated in float,
 * and the final product with the input is taken in half precision again.
 */
template <>
__device__ __half2
activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
  const float2 cubed = __half22float2(__hmul2(val, __hmul2(val, val)));
  const float2 v = __half22float2(val);
  float2 cdf;
  cdf.x =
      0.5f * (1.0f + tanhf((0.7978845608028654f * (v.x + 0.044715f * cubed.x))));
  cdf.y =
      0.5f * (1.0f + tanhf((0.7978845608028654f * (v.y + 0.044715f * cubed.y))));
  return __hmul2(val, __float22half2_rn(cdf));
}
/**
 * @brief ReLU for float: max(x, 0)
 */
template <>
__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
  const float zero = 0;
  return fmaxf(x, zero);
}
/**
 * @brief ReLU for two packed halves; each half is clamped in float and the
 * pair is repacked with round-to-nearest
 */
template <>
__device__ __half2
activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
  const float first = fmaxf(0.f, __half2float(x.x));
  const float second = fmaxf(0.f, __half2float(x.y));
  return __floats2half2_rn(first, second);
}
/**
 * @brief element-wise activation backward function on device
 *
 * Declaration of the primary template only; this file provides explicit
 * specializations for kGelu (float, __half) and kRelu (float, __half,
 * __half2).
 *
 * @tparam ActivationType activation selector, e.g. kRelu, kGelu
 * @tparam T value type
 * @param grad gradient flowing back from the activation output
 * @param x value that was fed to the activation in the forward pass
 * @return gradient with respect to x, same type as the input
 */
template <ActivationType, typename T>
__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
/**
 * @brief backward of the tanh-approximated GELU for float
 *
 * Returns grad * d/dx [ x * 0.5 * (1 + tanh(k * (x + c * x^3))) ] with
 * k = sqrt(2/pi) and c = 0.044715, split into three additive terms.
 */
template <>
__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
                                                                     float x) {
  const float k = 0.79788456080286535587989211986876f;
  const float c = 0.044715;
  const float cx2 = x * x * c;
  const float t = tanhf(k * (x + x * cx2));
  // 0.5 * (1 + tanh): derivative of the leading x factor
  const float term_cdf = 0.5f * (1.0f + t);
  // x * 0.5 * k * sech^2: derivative through tanh, linear part of the argument
  const float term_lin = x * 0.5f * k * (1 - t * t);
  // same, cubic part of the argument (3 * c * x^2)
  const float term_cub = term_lin * 3 * cx2;
  return grad * (term_cdf + term_lin + term_cub);
}
/**
 * @brief backward of the tanh-approximated GELU for __half
 *
 * The derivative is evaluated in float from the widened input, converted to
 * __half, and multiplied with grad in half precision.
 */
template <>
__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
    __half grad, __half x_half) {
  const float k = 0.79788456080286535587989211986876f;
  const float c = 0.044715;
  const float x = __half2float(x_half);
  const float cx2 = x * x * c;
  const float t = tanhf(k * (x + x * cx2));
  const float term_cdf = 0.5f * (1.0f + t);
  const float term_lin = x * 0.5f * k * (1 - t * t);
  const float term_cub = term_lin * 3 * cx2;
  return grad * __float2half(term_cdf + term_lin + term_cub);
}
/**
 * @brief ReLU backward for float: passes grad where x > 0, otherwise 0
 */
template <>
__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
                                                                     float x) {
  if (x > 0.f) {
    return grad;
  }
  return 0.f;
}
/**
 * @brief ReLU backward for __half: passes grad where x > 0, otherwise 0
 */
template <>
__device__ __half
activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
  const __half zero = __float2half(0.f);
  if (x > zero) {
    return grad;
  }
  return zero;
}
/**
 * @brief ReLU backward for two packed halves, applied to each half separately
 */
template <>
__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
    __half2 grad2, __half2 x_half2) {
  const __half zero = __float2half(0.f);
  const __half first = x_half2.x > zero ? grad2.x : zero;
  const __half second = x_half2.y > zero ? grad2.y : zero;
  return __floats2half2_rn(first, second);
}
/**
 * @brief init curand states in global memory
 *
 * @thread one thread per state; the kernel has no bounds check, so
 * gridDim.x * blockDim.x must not exceed the number of states
 * @param state persistent curand states
 * @param seed seed to init states
 * @return void
 */
__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
                                   int seed) {
  // Same seed for every thread, the global thread index as sequence number,
  // and no offset.
  const int state_id = blockIdx.x * blockDim.x + threadIdx.x;
  curand_init(seed, state_id, 0, state + state_id);
}
/**
 * @brief allocate and initialize the global curand state array
 *
 * Allocates at least total_count states in the global `curandstate` and
 * seeds them from the current wall-clock time (microseconds, narrowed to the
 * kernel's int seed parameter). Any array left from an earlier call is
 * released first. On allocation failure `curandstate` is left null and no
 * kernel is launched.
 *
 * @param total_count number of states required
 * @param dim unused
 * @param stream stream the init kernel is launched on
 */
void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
  (void)dim;
  constexpr int kThreadsPerBlock = 512;

  // Do not leak the array from a previous call.
  if (curandstate != nullptr) {
    cudaFree(curandstate);
    curandstate = nullptr;
  }
  if (total_count <= 0) return;

  // Round the grid up so counts below 512, or not a multiple of 512, are
  // fully covered. The kernel has no bounds check, so the allocation is
  // padded to a whole number of blocks.
  const int grid_dim = (total_count - 1) / kThreadsPerBlock + 1;
  const size_t num_states = static_cast<size_t>(grid_dim) * kThreadsPerBlock;
  if (cudaMalloc(&curandstate,
                 num_states * sizeof(curandStatePhilox4_32_10_t)) !=
      cudaSuccess) {
    curandstate = nullptr;
    return;
  }

  curand_init_kernel<<<grid_dim, kThreadsPerBlock, 0, stream>>>(
      curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
                       std::chrono::system_clock::now().time_since_epoch())
                       .count());
}
/**
 * @brief element-wise dropout (float), stores the keep mask, it's not
 * in-place
 *
 * Each thread handles 4 consecutive floats: one float4 load, one float4
 * store and one packed 32-bit mask store.
 *
 * @thread (as configured by launch_ls_dropout<float>)
 * gridDim.x = total_count / 4096 + 1
 * blockDim.x = 1024
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out output, same size as in
 * @param in input
 * @param mask uint8 type, same size with out (1 = kept, 0 = dropped)
 * @param seed seed to curand
 * @return void
 */
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  float *__restrict__ out,
                                  const float *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 4 >= total_count) return;

  // Per-thread Philox generator: common seed, subsequence = thread index.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_id, 0, &rng);
  const float4 u = curand_uniform4(&rng);

  uint8_t keep[4];
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);

  // Store the four mask bytes with a single 32-bit write.
  const uint32_t *packed_keep = reinterpret_cast<uint32_t *>(keep);
  reinterpret_cast<uint32_t *>(mask)[vec_id] = packed_keep[0];

  const float scale = 1.f / (1.f - ratio);
  const float4 x = reinterpret_cast<const float4 *>(in)[vec_id];
  float4 y;
  y.x = x.x * scale * keep[0];
  y.y = x.y * scale * keep[1];
  y.z = x.z * scale * keep[2];
  y.w = x.w * scale * keep[3];
  reinterpret_cast<float4 *>(out)[vec_id] = y;
}
/**
 * @brief element-wise dropout (__half), stores the keep mask, it's not
 * in-place
 *
 * Each thread handles 8 consecutive __half values: a 16-byte load/store
 * (typed as float4) viewed as four __half2 lanes, plus one packed 64-bit
 * mask store.
 */
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  __half *__restrict__ out,
                                  const __half *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 8 >= total_count) return;

  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_id, 0, &rng);

  // Two draws of four uniforms give the eight keep flags.
  uint8_t keep[8];
  float4 u = curand_uniform4(&rng);
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);
  u = curand_uniform4(&rng);
  keep[4] = static_cast<uint8_t>(u.x > ratio);
  keep[5] = static_cast<uint8_t>(u.y > ratio);
  keep[6] = static_cast<uint8_t>(u.z > ratio);
  keep[7] = static_cast<uint8_t>(u.w > ratio);

  const uint64_t *packed_keep = reinterpret_cast<uint64_t *>(keep);
  reinterpret_cast<uint64_t *>(mask)[vec_id] = *packed_keep;

  float4 in_vec = reinterpret_cast<const float4 *>(in)[vec_id];
  float4 out_vec;
  __half2 *in_h2 = reinterpret_cast<__half2 *>(&in_vec);
  __half2 *out_h2 = reinterpret_cast<__half2 *>(&out_vec);

  const float scale = 1.f / (1.f - ratio);
#pragma unroll
  for (int k = 0; k < 4; ++k) {
    const __half2 scale_keep =
        __floats2half2_rn(scale * keep[2 * k], scale * keep[2 * k + 1]);
    out_h2[k] = __hmul2(in_h2[k], scale_keep);
  }
  reinterpret_cast<float4 *>(out)[vec_id] = out_vec;
}
/**
 * @brief element-wise dropout backward (float) with dropout mask, it's
 * not in-place
 *
 * Each thread handles 4 consecutive floats and reads their 4 mask bytes with
 * one 32-bit load.
 *
 * @thread (as configured by launch_ls_dropout<float>)
 * gridDim.x = total_count / 4096 + 1
 * blockDim.x = 1024
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out result, same size as in
 * @param in incoming values to be masked and rescaled
 * @param mask uint8 type, same size with in
 * @return void
 */
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      float *out, const float *in,
                                      const uint8_t *__restrict__ mask) {
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 4 >= total_count) return;

  // Fetch four mask bytes at once and unpack them through a local array.
  uint8_t keep[4];
  uint32_t *packed_keep = reinterpret_cast<uint32_t *>(keep);
  packed_keep[0] = reinterpret_cast<const uint32_t *>(mask)[vec_id];

  const float scale = 1.f / (1.f - ratio);
  const float4 x = reinterpret_cast<const float4 *>(in)[vec_id];
  float4 y;
  y.x = x.x * scale * static_cast<float>(keep[0]);
  y.y = x.y * scale * static_cast<float>(keep[1]);
  y.z = x.z * scale * static_cast<float>(keep[2]);
  y.w = x.w * scale * static_cast<float>(keep[3]);
  reinterpret_cast<float4 *>(out)[vec_id] = y;
}
/**
 * @brief element-wise dropout backward (__half) with dropout mask, it's
 * not in-place
 *
 * Each thread handles 8 consecutive __half values (16 bytes, viewed as four
 * __half2 lanes) and reads their 8 mask bytes with one 64-bit load. The
 * scale factor is held in half precision.
 */
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      __half *out, const __half *in,
                                      const uint8_t *__restrict__ mask) {
  const __half scale = 1.f / (1.f - ratio);
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 8 >= total_count) return;

  uint8_t keep[8];
  uint64_t *packed_keep = reinterpret_cast<uint64_t *>(keep);
  packed_keep[0] = reinterpret_cast<const uint64_t *>(mask)[vec_id];

  float4 in_vec = reinterpret_cast<const float4 *>(in)[vec_id];
  float4 out_vec;
  __half2 *in_h2 = reinterpret_cast<__half2 *>(&in_vec);
  __half2 *out_h2 = reinterpret_cast<__half2 *>(&out_vec);

#pragma unroll
  for (int k = 0; k < 4; ++k) {
    const __half2 scale_keep =
        __halves2half2(scale * __float2half(keep[2 * k]),
                       scale * __float2half(keep[2 * k + 1]));
    out_h2[k] = __hmul2(in_h2[k], scale_keep);
  }
  reinterpret_cast<float4 *>(out)[vec_id] = out_vec;
}
/**
 * @brief launcher for float dropout, forward or backward
 *
 * A 1024-thread block covers 4096 floats (4 per thread); one extra block is
 * always added for the remainder. The forward pass is seeded with the
 * current wall-clock time in microseconds (narrowed to int by the kernel's
 * seed parameter); the backward pass only reads the mask.
 */
template <>
void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
                              int total_count, float ratio, cudaStream_t stream,
                              bool backward) {
  const int num_blocks = (total_count >> 12) + 1;
  if (backward) {
    ls_dropout_bwd_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio,
                                                           out, vals, mask);
    return;
  }
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto seed =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch)
          .count();
  ls_dropout_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio, out,
                                                     vals, mask, seed);
}
/**
 * @brief launcher for __half dropout, forward or backward
 *
 * A 1024-thread block covers 8192 halves (8 per thread); one extra block is
 * always added for the remainder. The forward pass is seeded with the
 * current wall-clock time in microseconds (narrowed to int by the kernel's
 * seed parameter); the backward pass only reads the mask.
 */
template <>
void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
                               int total_count, float ratio,
                               cudaStream_t stream, bool backward) {
  const int num_blocks = (total_count >> 13) + 1;
  if (backward) {
    ls_dropout_bwd_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio,
                                                           out, vals, mask);
    return;
  }
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto seed =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch)
          .count();
  ls_dropout_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio, out,
                                                     vals, mask, seed);
}
/**
 * @brief fused bias, dropout, and residual at the end of Attention and FFN
 * (float), stores the keep mask, it's not in-place
 *
 * Computes out = (in + bias) * scale * keep + residual, 4 floats per thread.
 *
 * @thread (as configured by launch_ls_dropout_res_bias<float>)
 * gridDim.x = total_count / 4096 + 1
 * blockDim.x = 1024
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out [batch_size, seq_len, hidden_size]
 * @param in [batch_size, seq_len, hidden_size]
 * @param mask [batch_size, seq_len, hidden_size], uint8 type
 * @param bias [hidden_size], ffn bias
 * @param residual [batch_size, seq_len, hidden_size]
 * @param seed seed to curand
 * @param hidden_size hidden size
 * @return void
 */
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const float *__restrict__ residual,
    const int seed, const int hidden_size) {
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 4 >= total_count) return;

  // Per-thread Philox generator: common seed, subsequence = thread index.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_id, 0, &rng);
  const float4 u = curand_uniform4(&rng);

  uint8_t keep[4];
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);

  const uint32_t *packed_keep = reinterpret_cast<uint32_t *>(keep);
  reinterpret_cast<uint32_t *>(mask)[vec_id] = packed_keep[0];

  // The bias repeats every hidden_size elements, i.e. hidden_size / 4 vectors.
  const int bias_vec = vec_id % (hidden_size >> 2);
  const float4 x = reinterpret_cast<const float4 *>(in)[vec_id];
  const float4 bv = __ldg(reinterpret_cast<const float4 *>(bias) + bias_vec);
  const float4 res = reinterpret_cast<const float4 *>(residual)[vec_id];

  const float scale = 1.f / (1.f - ratio);
  float4 y;
  y.x = (x.x + bv.x) * scale * keep[0] + res.x;
  y.y = (x.y + bv.y) * scale * keep[1] + res.y;
  y.z = (x.z + bv.z) * scale * keep[2] + res.z;
  y.w = (x.w + bv.w) * scale * keep[3] + res.w;
  reinterpret_cast<float4 *>(out)[vec_id] = y;
}
/**
 * @brief fused bias, dropout, and residual at the end of Attention and FFN
 * (__half), stores the keep mask, it's not in-place
 *
 * Computes out = (in + bias) * scale * keep + residual on 8 consecutive
 * __half values per thread (16 bytes, viewed as four __half2 lanes).
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out [batch_size, seq_len, hidden_size]
 * @param in [batch_size, seq_len, hidden_size]
 * @param mask [batch_size, seq_len, hidden_size], uint8 type
 * @param bias [hidden_size], ffn bias
 * @param residual [batch_size, seq_len, hidden_size]
 * @param seed seed to curand
 * @param hidden_size hidden size
 * @return void
 */
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const __half *__restrict__ residual,
    const int seed, const int hidden_size) {
  // Single-precision arithmetic, like the other kernels in this file; the
  // result is rounded to half anyway, so double precision buys nothing.
  const __half scale = __float2half(1.f / (1.f - ratio));
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  // Two draws of four uniforms give the eight keep flags (1 = kept).
  uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = static_cast<uint8_t>(rand.x > ratio);
  m[1] = static_cast<uint8_t>(rand.y > ratio);
  m[2] = static_cast<uint8_t>(rand.z > ratio);
  m[3] = static_cast<uint8_t>(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = static_cast<uint8_t>(rand.x > ratio);
  m[5] = static_cast<uint8_t>(rand.y > ratio);
  m[6] = static_cast<uint8_t>(rand.z > ratio);
  m[7] = static_cast<uint8_t>(rand.w > ratio);
  // store all eight mask bytes with one 64-bit write
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = m8[0];

  // The bias repeats every hidden_size elements, i.e. hidden_size / 8 vectors.
  int bias_i = i % (hidden_size >> 3);
  float4 val_float4 = vals_float4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  const float4 res4 = residual4[i];
  float4 out_float4;

  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
  const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);

#pragma unroll
  for (int k = 0; k < 4; ++k) {
    const __half2 scale_mask =
        __halves2half2(scale * __float2half(m[2 * k]),
                       scale * __float2half(m[2 * k + 1]));
    out_half2[k] =
        __hfma2(__hadd2(val_half2[k], b_half2[k]), scale_mask, res_half2[k]);
  }
  outs_float4[i] = out_float4;
}
/**
 * @brief launcher: bias + dropout + residual forward, float
 *
 * A 1024-thread block covers 4096 floats (4 per thread); one extra block is
 * always added for the remainder. The seed is the current wall-clock time in
 * microseconds, narrowed to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
                                       uint8_t *mask, const float *bias,
                                       const float *residual, int total_count,
                                       int dim, float ratio,
                                       cudaStream_t stream) {
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto seed =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch)
          .count();
  const int num_blocks = (total_count >> 12) + 1;
  ls_dropout_res_bias_kernel<<<num_blocks, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual, seed, dim);
}
/**
 * @brief launcher: bias + dropout + residual forward, __half
 *
 * A 1024-thread block covers 8192 halves (8 per thread); one extra block is
 * always added for the remainder. The seed is the current wall-clock time in
 * microseconds, narrowed to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
                                        uint8_t *mask, const __half *bias,
                                        const __half *residual, int total_count,
                                        int dim, float ratio,
                                        cudaStream_t stream) {
  const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
  const auto seed =
      std::chrono::duration_cast<std::chrono::microseconds>(since_epoch)
          .count();
  const int num_blocks = (total_count >> 13) + 1;
  ls_dropout_res_bias_kernel<<<num_blocks, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual, seed, dim);
}
/**
 * @brief fused bias and dropout backward at the end of Attention and FFN
 *
 * @thread
 * gridDim.x = ceil(hidden_size / 8)
 * blockDim.x = 8
 * blockDim.y = 1024 / 8 = 128
 *
 * @param row_size batch_size * seq_len
 * @param ratio dropout ratio
 * @param in_grad [batch_size, seq_len, hidden_size], input grad
 * @param bias_grad [hidden_size], bias grad
 * @param out_grad [batch_size, seq_len, hidden_size], output grad
 * @param mask [batch_size, seq_len, hidden_size], dropout mask
 * @param hidden_size
 * @return void
 */
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, float *__restrict__ in_grad,
    float *__restrict__ bias_grad, const float *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  // every block generates 8 bias results; rows hold 129 entries instead of
  // 128 (presumably padding against shared-memory bank conflicts)
  __shared__ float tile[8][129];
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  // NOTE(review): flat_2dim(a, b, n) is presumably a * n + b -- confirm.
  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  int stride = hidden_size * 128;
  float local_sum = 0;
  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  // The launcher rounds the grid up, so the last block may own columns past
  // hidden_size; those threads must not touch memory.
  if (col_idx < hidden_size) {
    for (int r = threadIdx.y; r < row_size; r += 128) {
      float val = out_grad[idx];
      val *= scale * static_cast<float>(mask[idx]);
      local_sum += val;
      in_grad[idx] = val;
      idx += stride;
    }
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  // Re-index the block as 8 groups of 128 threads: x picks the column,
  // y the position inside the group. The first 32 threads of each group
  // fold the 128 partial sums down to 32.
  float sum = 0;
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  // all reads of tile must finish before tile[0][x] is overwritten below
  __syncthreads();

  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  // One row of threads is enough to publish the 8 results; skip columns
  // that lie past hidden_size.
  if (threadIdx.y == 0 && threadIdx.x < 8) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
    if (pos < hidden_size) bias_grad[pos] = tile[0][threadIdx.x];
  }
}
/**
 * @brief fused bias and dropout backward, __half overload
 *
 * Data is accessed as __half2, so hidden_size is the number of __half2
 * columns (the __half launcher in this file passes dim >> 1).
 *
 * @thread
 * gridDim.x = ceil(hidden_size / 8)
 * blockDim.x = 8
 * blockDim.y = 128
 *
 * @param row_size number of rows reduced into bias_grad
 * @param ratio dropout ratio
 * @param in_grad input grad, written element-wise
 * @param bias_grad bias grad, one __half2 per column
 * @param out_grad output grad
 * @param mask dropout mask, one uint8 per __half element
 * @param hidden_size number of __half2 columns
 * @return void
 */
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, __half *__restrict__ in_grad,
    __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
  // every block generates 8 bias results; rows hold 129 entries instead of
  // 128 (presumably padding against shared-memory bank conflicts)
  __shared__ __half2 tile[8][129];
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
  const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
  __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);

  // NOTE(review): flat_2dim(a, b, n) is presumably a * n + b -- confirm.
  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  int stride = hidden_size * 128;
  __half2 local_sum = __float2half2_rn(0.f);
  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  // The launcher rounds the grid up, so the last block may own columns past
  // hidden_size; those threads must not touch memory.
  if (col_idx < hidden_size) {
    for (int r = threadIdx.y; r < row_size; r += 128) {
      __half2 val = out_grad2[idx];
      __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
      val *= scale * m2;
      local_sum += val;
      in_grad2[idx] = val;
      idx += stride;
    }
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  // Re-index the block as 8 groups of 128 threads: x picks the column,
  // y the position inside the group. The first 32 threads of each group
  // fold the 128 partial sums down to 32.
  __half2 sum = __float2half2_rn(0.f);
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  // all reads of tile must finish before tile[0][x] is overwritten below
  __syncthreads();

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  // One row of threads is enough to publish the 8 results; skip columns
  // that lie past hidden_size.
  if (threadIdx.y == 0 && threadIdx.x < 8) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
    if (pos < hidden_size) bias_grad2[pos] = tile[0][threadIdx.x];
  }
}
template <typename T>
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream) {
dim3 grid_dim((dim - 1) / 8 + 1);
dim3 block_dim(8, 128);
ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}
template <>
void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
                                const __half *out_grad, const uint8_t *mask,
                                int row_size, int dim, float ratio,
                                cudaStream_t stream) {
  // The __half kernel processes __half2 pairs, so it sees half as many columns.
  const int dim2 = dim >> 1;
  const dim3 threads(8, 128);
  const dim3 blocks((dim2 - 1) / 8 + 1);
  ls_dropout_bias_bwd_kernel<<<blocks, threads, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, dim2);
}
// Explicit instantiation of the generic launcher for float. The __half case is
// handled by the full specialization of launch_ls_dropout_bias_bwd instead.
template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
const float *out_grad,
const uint8_t *mask, int row_size,
int dim, float ratio,
cudaStream_t stream);
/**
* @brief fused bias, activation, and dropout at the end of first ffn
*
* @thread
* gridDim.x = total_count / 1024 + 1 (float) or total_count / 2048 + 1 (__half)
* blockDim.x = 256
*
* @tparam act_type activation function, like kRelu, kGelu
* @param total_count total elements
* @param ratio drop ratio
* @param out [batch_size, seq_len, hidden_size], float and __half
* @param in [batch_size, seq_len, hidden_size], float and __half
* @param mask [batch_size, seq_len, hidden_size], uint8 type
* @param bias [hidden_size], ffn bias
* @param seed seed to curand
* @param hidden_size
* @return void
*/
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const int seed, const int hidden_size) {
  // Each thread owns one float4, i.e. four consecutive elements.
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 4 >= total_count) return;

  // Kept values are multiplied by 1 / (1 - ratio).
  const float keep_scale = 1.f / (1.f - ratio);

  // Per-thread Philox state: shared seed, subsequence = vec_id, offset 0.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_id, 0, &rng);
  const float4 u = curand_uniform4(&rng);

  // keep[k] is 1 when element k survives dropout, 0 otherwise.
  uint8_t keep[4];
  keep[0] = static_cast<uint8_t>(u.x > ratio);
  keep[1] = static_cast<uint8_t>(u.y > ratio);
  keep[2] = static_cast<uint8_t>(u.z > ratio);
  keep[3] = static_cast<uint8_t>(u.w > ratio);

  // The four mask bytes are stored with a single 32-bit write.
  reinterpret_cast<uint32_t *>(mask)[vec_id] =
      *reinterpret_cast<uint32_t *>(keep);

  // The bias repeats every hidden_size elements, i.e. every hidden_size / 4
  // float4 vectors.
  const float4 x = reinterpret_cast<const float4 *>(in)[vec_id];
  const float4 b = __ldg(
      &reinterpret_cast<const float4 *>(bias)[vec_id % (hidden_size >> 2)]);

  float4 y;
  y.x = activation_kernel<act_type, float>(x.x + b.x) * keep_scale * keep[0];
  y.y = activation_kernel<act_type, float>(x.y + b.y) * keep_scale * keep[1];
  y.z = activation_kernel<act_type, float>(x.z + b.z) * keep_scale * keep[2];
  y.w = activation_kernel<act_type, float>(x.w + b.w) * keep_scale * keep[3];
  reinterpret_cast<float4 *>(out)[vec_id] = y;
}
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const int seed, const int hidden_size) {
  // Each thread owns one 16-byte vector, i.e. eight consecutive __half values.
  const int vec_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_id * 8 >= total_count) return;

  // Kept values are multiplied by 1 / (1 - ratio).
  const float keep_scale = 1.f / (1.f - ratio);

  // Per-thread Philox state: shared seed, subsequence = vec_id, offset 0.
  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_id, 0, &rng);

  // Two draws of four uniforms produce the eight keep/drop flags.
  uint8_t keep[8];
#pragma unroll
  for (int draw = 0; draw < 2; ++draw) {
    const float4 u = curand_uniform4(&rng);
    keep[4 * draw + 0] = static_cast<uint8_t>(u.x > ratio);
    keep[4 * draw + 1] = static_cast<uint8_t>(u.y > ratio);
    keep[4 * draw + 2] = static_cast<uint8_t>(u.z > ratio);
    keep[4 * draw + 3] = static_cast<uint8_t>(u.w > ratio);
  }

  // The eight mask bytes are stored with a single 64-bit write.
  reinterpret_cast<uint64_t *>(mask)[vec_id] =
      *reinterpret_cast<uint64_t *>(keep);

  // The bias repeats every hidden_size elements, i.e. every hidden_size / 8
  // 16-byte vectors.
  float4 x = reinterpret_cast<const float4 *>(in)[vec_id];
  const float4 b = __ldg(
      &reinterpret_cast<const float4 *>(bias)[vec_id % (hidden_size >> 3)]);
  float4 y;

  // View the 16-byte registers as four __half2 pairs each.
  __half2 *x2 = reinterpret_cast<__half2 *>(&x);
  const __half2 *b2 = reinterpret_cast<const __half2 *>(&b);
  __half2 *y2 = reinterpret_cast<__half2 *>(&y);

#pragma unroll
  for (int k = 0; k < 4; ++k) {
    // Scale and mask are combined in float, then rounded to half once.
    const __half2 scale_mask = __floats2half2_rn(keep_scale * keep[2 * k],
                                                 keep_scale * keep[2 * k + 1]);
    y2[k] = __hmul2(
        activation_kernel<act_type, __half2>(__hadd2(x2[k], b2[k])),
        scale_mask);
  }
  reinterpret_cast<float4 *>(out)[vec_id] = y;
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  // 256 threads x 4 floats = 1024 elements per block, plus one extra block.
  const int num_blocks = (total_count >> 10) + 1;
  // Wall-clock microseconds since the epoch serve as the RNG seed.
  const auto seed_us = std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count();
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, seed_us, dim);
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  // 256 threads x 8 halves = 2048 elements per block, plus one extra block.
  const int num_blocks = (total_count >> 11) + 1;
  // Wall-clock microseconds since the epoch serve as the RNG seed.
  const auto seed_us = std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count();
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, seed_us, dim);
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  // 256 threads x 4 floats = 1024 elements per block, plus one extra block.
  const int num_blocks = (total_count >> 10) + 1;
  // Wall-clock microseconds since the epoch serve as the RNG seed.
  const auto seed_us = std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count();
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, seed_us, dim);
}
template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  // 256 threads x 8 halves = 2048 elements per block, plus one extra block.
  const int num_blocks = (total_count >> 11) + 1;
  // Wall-clock microseconds since the epoch serve as the RNG seed.
  const auto seed_us = std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count();
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, seed_us, dim);
}
/**
* @brief fused bias, activation, and dropout backward
*
* @thread
* gridDim.x = ceil(hidden_size / WARP_SIZE)
* blockDim.x = WARP_SIZE
* blockDim.y = WARP_SIZE
*
* @tparam act_type activation function, kRelu or kGelu
* @param row_size batch_size * seq_len
* @param ratio dropout ratio
* @param in_grad [batch_size, seq_len, hidden_size], input grad
* @param bias_grad [hidden_size], bias grad
* @param input [batch_size, seq_len, hidden_size], forward input before the
* bias is added
* @param bias [hidden_size], ffn bias
* @param out_grad [batch_size, seq_len, hidden_size], output grad
* @param mask [batch_size, seq_len, hidden_size], dropout mask
* @param hidden_size
* @return void
*/
template <ActivationType act_type, typename T>
__global__ void ls_dropout_act_bias_bwd_kernel(
    const int row_size, const float ratio, T *in_grad,
    T *__restrict__ bias_grad, const T *__restrict__ input,
    const T *__restrict__ bias, const T *out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  // Backward of fused bias + activation + dropout.
  //   in_grad[r][c] = act'(out_grad[r][c] * mask[r][c] / (1 - ratio),
  //                        input[r][c] + bias[c])
  //   bias_grad[c]  = sum over r of in_grad[r][c]
  // All arithmetic is done in float regardless of T. in_grad and out_grad are
  // deliberately not __restrict__, so they may refer to the same buffer.
  // The indexing below assumes a (WARP_SIZE, WARP_SIZE) thread block.
  const float scale = 1.f / (1.f - ratio);

  // One extra column of padding keeps the transposed accesses below from
  // mapping a whole warp onto the same shared-memory bank.
  __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];

  cg::thread_block block = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> warp =
      cg::tiled_partition<WARP_SIZE>(block);

  // Column handled by this thread; `stride` moves WARP_SIZE rows down.
  const int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
  const int stride = hidden_size * WARP_SIZE;
  float local_sum = 0;

  // The grid is rounded up, so the last block may have columns past the end.
  if (col_idx < hidden_size) {
    // The bias depends only on the column, so load it once per thread
    // (idx % hidden_size == col_idx for every idx visited below).
    const float bias_val = bias[col_idx];
    int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
    for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
      float val = out_grad[idx];
      const float in = input[idx];
      val = activation_bwd_kernel<act_type, float>(
          val * scale * static_cast<float>(mask[idx]), in + bias_val);
      local_sum += val;
      in_grad[idx] = val;
      idx += stride;
    }
  }

  // Transpose through shared memory so that each warp ends up holding all
  // WARP_SIZE partial sums of a single column.
  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();
  float sum = tile[threadIdx.y][threadIdx.x];
  // Every read of `tile` must complete before tile[0][*] is reused below.
  __syncthreads();

  // Shuffle-down reduction; lane 0 of each warp ends up with the column total.
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += warp.shfl_down(sum, i);
  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
  __syncthreads();

  if (threadIdx.y == 0) {
    const int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
    // Guard the store: without it the last block writes past the end of
    // bias_grad whenever hidden_size is not a multiple of WARP_SIZE.
    if (pos < hidden_size) bias_grad[pos] = tile[0][threadIdx.x];
  }
}
// @brief fused bias, activation, and dropout backward
// It is deprecated for precision reasons. Kept for future optimization.
//
// template <ActivationType act_type>
// __global__ void ls_dropout_act_bias_bwd_kernel(
// const int row_size, const float ratio, __half * in_grad,
// __half *__restrict__ bias_grad, const __half *__restrict__ input, const
// __half *__restrict__ bias, const __half * out_grad, const uint8_t
// *__restrict__ mask, const int hidden_size) {
// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];
// cg::thread_block b = cg::this_thread_block();
// cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
// const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
// const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
// const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);
// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// int stride = hidden_size * WARP_SIZE;
// __half2 local_sum = __float2half2_rn(0.f);
// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
// if (col_idx < hidden_size) {
// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
// __half2 val = out_grad2[idx];
// __half2 in2 = input2[idx];
// __half2 b2 = bias2[idx % hidden_size ];
// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
// val = activation_bwd_kernel<ActivationType::kRelu, __half2>(val * scale
// *
// m2,
// in2+b2);
// local_sum += val;
// in_grad2[idx] = val;
// idx += stride;
// }
// }
// tile[threadIdx.x][threadIdx.y] = local_sum;
// __syncthreads();
// __half2 sum = tile[threadIdx.y][threadIdx.x];
// __syncthreads();
// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
// __syncthreads();
// if (threadIdx.y == 0) {
// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// bias_grad2[pos] = tile[0][threadIdx.x];
// }
// }
template <ActivationType act_type, typename T>
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
const T *bias, const T *out_grad,
const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream) {
dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE);
ls_dropout_act_bias_bwd_kernel<act_type><<<grid_dim, block_dim, 0, stream>>>(
row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim);
}
// template <>
// void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
// __half *in_grad, __half *bias_grad,const __half *input, const __half
// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int
// dim, float ratio, cudaStream_t stream) {
// dim >>= 1;
// dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
// dim3 block_dim(WARP_SIZE, WARP_SIZE);
// ls_dropout_act_bias_bwd_kernel<ActivationType::kRelu>
// <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
// bias_grad,
// input, bias,out_grad, mask, dim);
// }
// Explicit instantiations of the backward launcher: kRelu and kGelu, each for
// float and __half.
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
float *in_grad, float *bias_grad, const float *input, const float *bias,
const float *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
__half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
const __half *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
float *in_grad, float *bias_grad, const float *input, const float *bias,
const float *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
__half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
const __half *out_grad, const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
#include <cooperative_groups.h>
#include "kernels.h"
namespace cg = cooperative_groups;
/**
@brief: fuse_transpose_bias
Calculate the sum of elements in each column of the matrix.
@thread
gridDim.x = ceil(cols / WARP_SIZE)
blockDim.x = WARP_SIZE
blockDim.y = WARP_SIZE
@param
inp: [rows, cols]
out: [cols]
rows: the number of rows in the matrix
cols: the number of cols in the matrix
*/
template <typename T>
__global__ void column_sum_reduce(const T *__restrict__ inp,
                                  T *__restrict__ out, int rows, int cols) {
  // Sums every column of the row-major [rows, cols] matrix `inp` into
  // out[cols], accumulating in float. The indexing assumes a
  // (WARP_SIZE, WARP_SIZE) thread block covering WARP_SIZE columns.
  //
  // The extra padding column makes consecutive tile rows start in different
  // shared-memory banks, so the transposed store below (threadIdx.x selects
  // the row) is not serialized into a 32-way bank conflict.
  __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  // Column handled by this thread; `y_stride` moves WARP_SIZE rows down.
  int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
  int y_stride = cols * WARP_SIZE;
  float localSum = 0;

  // Loop across matrix rows: this thread takes rows threadIdx.y,
  // threadIdx.y + WARP_SIZE, ... of its column.
  // TODO: optimize to log complexity
  if (idx < cols) {
    int offset = flat_2dim(threadIdx.y, idx, cols);
    for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
      localSum += (float)inp[offset];
      offset += y_stride;
    }
  }

  // The sum of a row in tile is equal to the sum of a column in the matrix.
  tile[threadIdx.x][threadIdx.y] = localSum;
  __syncthreads();

  // Read back transposed: consecutive threadIdx.x now walk along one tile
  // row, so each warp holds all partial sums of a single column.
  float sum = tile[threadIdx.y][threadIdx.x];
  __syncthreads();

  // Shuffle-down reduction; lane 0 of each warp ends up with the row total.
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (threadIdx.x == 0) {
    int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
    // The grid is rounded up, so skip columns past the end.
    if (pos < cols) out[pos] = sum;
  }
}
// [r, c] -> [c]
template <>
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
                                              int rows, int cols,
                                              cudaStream_t stream) {
  // One WARP_SIZE x WARP_SIZE block per WARP_SIZE columns, rounded up.
  const dim3 threads(WARP_SIZE, WARP_SIZE);
  const dim3 blocks((cols - 1) / WARP_SIZE + 1);
  column_sum_reduce<float>
      <<<blocks, threads, 0, stream>>>(inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
                                               int rows, int cols,
                                               cudaStream_t stream) {
  // One WARP_SIZE x WARP_SIZE block per WARP_SIZE columns, rounded up.
  const dim3 threads(WARP_SIZE, WARP_SIZE);
  const dim3 blocks((cols - 1) / WARP_SIZE + 1);
  column_sum_reduce<__half>
      <<<blocks, threads, 0, stream>>>(inp, out, rows, cols);
}
/**
@brief: fused_add2
Add two matrices inp1 and inp2 element-wise and store the result in out.
@thread
gridDim.x = batch_size * seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
inp1: [batch_size, seq_len, hidden_dim]
inp2: [batch_size, seq_len, hidden_dim]
out: [batch_size, seq_len, hidden_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
*/
// Primary template: declaration only. The kernel is defined through the
// explicit specializations for float and __half that follow.
template <typename T>
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
int hidden_dim);
template <>
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
                                         const float *inp2, int hidden_dim) {
  // All buffers are accessed as float4; hidden_dim counts float4 vectors per
  // row and each block handles the row selected by blockIdx.x.
  const float4 *src1 = reinterpret_cast<const float4 *>(inp1);
  const float4 *src2 = reinterpret_cast<const float4 *>(inp2);
  float4 *dst = reinterpret_cast<float4 *>(out);

  const int row_base = flat_2dim(blockIdx.x, 0, hidden_dim);

  // Threads of the block step through the row blockDim.x vectors at a time.
  for (std::size_t col = threadIdx.x; col < hidden_dim; col += blockDim.x) {
    const std::size_t pos = row_base + col;
    const float4 a = src1[pos];
    const float4 b = src2[pos];
    dst[pos] = make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
  }
}
template <>
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
                                          const __half *inp2, int hidden_dim) {
  // All buffers are accessed as 16-byte float4 vectors (eight halves each);
  // hidden_dim counts such vectors per row and each block handles the row
  // selected by blockIdx.x.
  const float4 *src1 = reinterpret_cast<const float4 *>(inp1);
  const float4 *src2 = reinterpret_cast<const float4 *>(inp2);
  float4 *dst = reinterpret_cast<float4 *>(out);

  const int row_base = flat_2dim(blockIdx.x, 0, hidden_dim);

  // Threads of the block step through the row blockDim.x vectors at a time.
  for (std::size_t col = threadIdx.x; col < hidden_dim; col += blockDim.x) {
    const std::size_t pos = row_base + col;
    float4 a = src1[pos];
    float4 b = src2[pos];
    float4 sum;

    // Reinterpret each 16-byte register as four __half2 pairs.
    __half2 *a2 = reinterpret_cast<__half2 *>(&a);
    __half2 *b2 = reinterpret_cast<__half2 *>(&b);
    __half2 *sum2 = reinterpret_cast<__half2 *>(&sum);
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      sum2[k] = __hadd2(a2[k], b2[k]);
    }
    dst[pos] = sum;
  }
}
//[b, s, h] -> [b, s, h]
template <>
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
                              int batch_size, int seq_len, int hidden_dim,
                              cudaStream_t &stream) {
  // The kernel moves float4 vectors, i.e. four floats at a time.
  const int vecs_per_row = hidden_dim >> 2;
  // One block per (batch, position) row.
  const dim3 blocks(batch_size * seq_len);
  const dim3 threads(min(vecs_per_row, MAX_THREADS));
  fused_add2_kernel<<<blocks, threads, 0, stream>>>(out, inp1, inp2,
                                                    vecs_per_row);
}
template <>
void launch_fused_add2<__half>(__half *out, const __half *inp1,
                               const __half *inp2, int batch_size, int seq_len,
                               int hidden_dim, cudaStream_t &stream) {
  // The kernel moves 16-byte vectors, i.e. eight halves at a time.
  const int vecs_per_row = hidden_dim >> 3;
  // One block per (batch, position) row.
  const dim3 blocks(batch_size * seq_len);
  const dim3 threads(min(vecs_per_row, MAX_THREADS));
  fused_add2_kernel<<<blocks, threads, 0, stream>>>(out, inp1, inp2,
                                                    vecs_per_row);
}
template <typename T>
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
                                    int sz0, int sz2, int sz1_1, int sz1_2) {
  // Concatenates inp1 [sz0, sz1_1, sz2] and inp2 [sz0, sz1_2, sz2] along
  // dimension 1. Buffers are accessed as float4, so sz2 counts 16-byte
  // vectors; each thread copies exactly one vector of the output.
  const int sz1_total = sz1_1 + sz1_2;
  const int flat = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
  if (flat >= sz0 * sz2 * sz1_total) {
    return;
  }

  // Decompose the flat output index into (i0, i1, i2).
  const int i2 = flat % sz2;
  const int outer = flat / sz2;
  const int i1 = outer % sz1_total;
  const int i0 = outer / sz1_total;

  // Rows below sz1_1 come from inp1, the remaining rows from inp2.
  const bool from_first = i1 < sz1_1;
  const float4 *src = from_first ? (const float4 *)inp1 : (const float4 *)inp2;
  const int src_i1 = from_first ? i1 : i1 - sz1_1;
  const int src_sz1 = from_first ? sz1_1 : sz1_2;

  float4 *dst = (float4 *)output;
  dst[flat] = src[flat_3dim(i0, src_i1, i2, src_sz1, sz2)];
}
template <>
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
                                float *output, int sz0, int sz2, int sz1_1,
                                int sz1_2, cudaStream_t stream) {
  // The kernel copies float4 vectors, i.e. four floats per thread.
  const int vec_sz2 = sz2 >> 2;
  const int total_vecs = sz0 * vec_sz2 * (sz1_1 + sz1_2);
  // Enough MAX_THREADS-sized blocks to give every vector its own thread.
  const int num_blocks = (total_vecs + MAX_THREADS - 1) / MAX_THREADS;
  kernel_concat3_dim1<<<num_blocks, MAX_THREADS, 0, stream>>>(
      inp1, inp2, output, sz0, vec_sz2, sz1_1, sz1_2);
}
template <>
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
                                 __half *output, int sz0, int sz2, int sz1_1,
                                 int sz1_2, cudaStream_t stream) {
  // The kernel copies 16-byte vectors, i.e. eight halves per thread.
  const int vec_sz2 = sz2 >> 3;
  const int total_vecs = sz0 * vec_sz2 * (sz1_1 + sz1_2);
  // Enough MAX_THREADS-sized blocks to give every vector its own thread.
  const int num_blocks = (total_vecs + MAX_THREADS - 1) / MAX_THREADS;
  kernel_concat3_dim1<<<num_blocks, MAX_THREADS, 0, stream>>>(
      inp1, inp2, output, sz0, vec_sz2, sz1_1, sz1_2);
}
#include <cooperative_groups.h>
#include "kernels.h"
namespace cg = cooperative_groups;
/**
@brief: fuse_transpose_bias
Calculate the sum of elements in each column of the matrix.
@thread
gridDim.x = ceil(cols / WARP_SIZE)
blockDim.x = WARP_SIZE
blockDim.y = WARP_SIZE
@param
inp: [rows, cols]
out: [cols]
rows: the number of rows in the matrix
cols: the number of cols in the matrix
*/
template <typename T>
__global__ void column_sum_reduce(const T *__restrict__ inp,
                                  T *__restrict__ out, int rows, int cols) {
  // Sums every column of the row-major [rows, cols] matrix `inp` into
  // out[cols], accumulating in float. The indexing assumes a
  // (WARP_SIZE, WARP_SIZE) thread block covering WARP_SIZE columns.
  //
  // The extra padding column makes consecutive tile rows start in different
  // shared-memory banks, so the transposed store below (threadIdx.x selects
  // the row) is not serialized into a 32-way bank conflict.
  __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  // Column handled by this thread; `y_stride` moves WARP_SIZE rows down.
  int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
  int y_stride = cols * WARP_SIZE;
  float localSum = 0;

  // Loop across matrix rows: this thread takes rows threadIdx.y,
  // threadIdx.y + WARP_SIZE, ... of its column.
  // TODO: optimize to log complexity
  if (idx < cols) {
    int offset = flat_2dim(threadIdx.y, idx, cols);
    for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
      localSum += (float)inp[offset];
      offset += y_stride;
    }
  }

  // The sum of a row in tile is equal to the sum of a column in the matrix.
  tile[threadIdx.x][threadIdx.y] = localSum;
  __syncthreads();

  // Read back transposed: consecutive threadIdx.x now walk along one tile
  // row, so each warp holds all partial sums of a single column.
  float sum = tile[threadIdx.y][threadIdx.x];
  __syncthreads();

  // Shuffle-down reduction; lane 0 of each warp ends up with the row total.
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (threadIdx.x == 0) {
    int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
    // The grid is rounded up, so skip columns past the end.
    if (pos < cols) out[pos] = sum;
  }
}
// [r, c] -> [c]
template <>
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
                                              int rows, int cols,
                                              cudaStream_t stream) {
  // One WARP_SIZE x WARP_SIZE block per WARP_SIZE columns, rounded up.
  const dim3 threads(WARP_SIZE, WARP_SIZE);
  const dim3 blocks((cols - 1) / WARP_SIZE + 1);
  column_sum_reduce<float>
      <<<blocks, threads, 0, stream>>>(inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
                                               int rows, int cols,
                                               cudaStream_t stream) {
  // One WARP_SIZE x WARP_SIZE block per WARP_SIZE columns, rounded up.
  const dim3 threads(WARP_SIZE, WARP_SIZE);
  const dim3 blocks((cols - 1) / WARP_SIZE + 1);
  column_sum_reduce<__half>
      <<<blocks, threads, 0, stream>>>(inp, out, rows, cols);
}
/**
@brief: fused_add2
Add two matrices inp1 and inp2 element-wise and store the result in out.
@thread
gridDim.x = batch_size * seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
inp1: [batch_size, seq_len, hidden_dim]
inp2: [batch_size, seq_len, hidden_dim]
out: [batch_size, seq_len, hidden_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
*/
// Primary template: declaration only. The kernel is defined through the
// explicit specializations for float and __half that follow.
template <typename T>
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
int hidden_dim);
template <>
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
                                         const float *inp2, int hidden_dim) {
  // All buffers are accessed as float4; hidden_dim counts float4 vectors per
  // row and each block handles the row selected by blockIdx.x.
  const float4 *src1 = reinterpret_cast<const float4 *>(inp1);
  const float4 *src2 = reinterpret_cast<const float4 *>(inp2);
  float4 *dst = reinterpret_cast<float4 *>(out);

  const int row_base = flat_2dim(blockIdx.x, 0, hidden_dim);

  // Threads of the block step through the row blockDim.x vectors at a time.
  for (std::size_t col = threadIdx.x; col < hidden_dim; col += blockDim.x) {
    const std::size_t pos = row_base + col;
    const float4 a = src1[pos];
    const float4 b = src2[pos];
    dst[pos] = make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
  }
}
template <>
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
                                          const __half *inp2, int hidden_dim) {
  // All buffers are accessed as 16-byte float4 vectors (eight halves each);
  // hidden_dim counts such vectors per row and each block handles the row
  // selected by blockIdx.x.
  const float4 *src1 = reinterpret_cast<const float4 *>(inp1);
  const float4 *src2 = reinterpret_cast<const float4 *>(inp2);
  float4 *dst = reinterpret_cast<float4 *>(out);

  const int row_base = flat_2dim(blockIdx.x, 0, hidden_dim);

  // Threads of the block step through the row blockDim.x vectors at a time.
  for (std::size_t col = threadIdx.x; col < hidden_dim; col += blockDim.x) {
    const std::size_t pos = row_base + col;
    float4 a = src1[pos];
    float4 b = src2[pos];
    float4 sum;

    // Reinterpret each 16-byte register as four __half2 pairs.
    __half2 *a2 = reinterpret_cast<__half2 *>(&a);
    __half2 *b2 = reinterpret_cast<__half2 *>(&b);
    __half2 *sum2 = reinterpret_cast<__half2 *>(&sum);
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      sum2[k] = __hadd2(a2[k], b2[k]);
    }
    dst[pos] = sum;
  }
}
//[b, s, h] -> [b, s, h]
template <>
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
                              int batch_size, int seq_len, int hidden_dim,
                              cudaStream_t &stream) {
  // The kernel moves float4 vectors, i.e. four floats at a time.
  const int vecs_per_row = hidden_dim >> 2;
  // One block per (batch, position) row.
  const dim3 blocks(batch_size * seq_len);
  const dim3 threads(min(vecs_per_row, MAX_THREADS));
  fused_add2_kernel<<<blocks, threads, 0, stream>>>(out, inp1, inp2,
                                                    vecs_per_row);
}
template <>
void launch_fused_add2<__half>(__half *out, const __half *inp1,
                               const __half *inp2, int batch_size, int seq_len,
                               int hidden_dim, cudaStream_t &stream) {
  // The kernel moves 16-byte vectors, i.e. eight halves at a time.
  const int vecs_per_row = hidden_dim >> 3;
  // One block per (batch, position) row.
  const dim3 blocks(batch_size * seq_len);
  const dim3 threads(min(vecs_per_row, MAX_THREADS));
  fused_add2_kernel<<<blocks, threads, 0, stream>>>(out, inp1, inp2,
                                                    vecs_per_row);
}
template <typename T>
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
                                    int sz0, int sz2, int sz1_1, int sz1_2) {
  // Concatenates inp1 [sz0, sz1_1, sz2] and inp2 [sz0, sz1_2, sz2] along
  // dimension 1. Buffers are accessed as float4, so sz2 counts 16-byte
  // vectors; each thread copies exactly one vector of the output.
  const int sz1_total = sz1_1 + sz1_2;
  const int flat = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
  if (flat >= sz0 * sz2 * sz1_total) {
    return;
  }

  // Decompose the flat output index into (i0, i1, i2).
  const int i2 = flat % sz2;
  const int outer = flat / sz2;
  const int i1 = outer % sz1_total;
  const int i0 = outer / sz1_total;

  // Rows below sz1_1 come from inp1, the remaining rows from inp2.
  const bool from_first = i1 < sz1_1;
  const float4 *src = from_first ? (const float4 *)inp1 : (const float4 *)inp2;
  const int src_i1 = from_first ? i1 : i1 - sz1_1;
  const int src_sz1 = from_first ? sz1_1 : sz1_2;

  float4 *dst = (float4 *)output;
  dst[flat] = src[flat_3dim(i0, src_i1, i2, src_sz1, sz2)];
}
template <>
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
                                float *output, int sz0, int sz2, int sz1_1,
                                int sz1_2, cudaStream_t stream) {
  // The kernel copies float4 vectors, i.e. four floats per thread.
  const int vec_sz2 = sz2 >> 2;
  const int total_vecs = sz0 * vec_sz2 * (sz1_1 + sz1_2);
  // Enough MAX_THREADS-sized blocks to give every vector its own thread.
  const int num_blocks = (total_vecs + MAX_THREADS - 1) / MAX_THREADS;
  kernel_concat3_dim1<<<num_blocks, MAX_THREADS, 0, stream>>>(
      inp1, inp2, output, sz0, vec_sz2, sz1_1, sz1_2);
}
template <>
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
                                 __half *output, int sz0, int sz2, int sz1_1,
                                 int sz1_2, cudaStream_t stream) {
  // The kernel copies 16-byte vectors, i.e. eight halves per thread.
  const int vec_sz2 = sz2 >> 3;
  const int total_vecs = sz0 * vec_sz2 * (sz1_1 + sz1_2);
  // Enough MAX_THREADS-sized blocks to give every vector its own thread.
  const int num_blocks = (total_vecs + MAX_THREADS - 1) / MAX_THREADS;
  kernel_concat3_dim1<<<num_blocks, MAX_THREADS, 0, stream>>>(
      inp1, inp2, output, sz0, vec_sz2, sz1_1, sz1_2);
}
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <string>
#include "kernels.h"
// Thin wrapper around the fused dropout launchers. It owns the byte mask that
// the forward launchers fill and the backward launchers read.
template <typename T>
class Dropout {
 public:
  // Dropout settings; the effective ratio is 0 when not training.
  struct Config {
    float ratio;
    bool training;

    Config(float r) : ratio(r), training(true) {}

    // Ratio actually applied: `ratio` while training, otherwise 0.
    float RATIO() const {
      if (training) return ratio;
      return 0.0;
    }
  };

  // Allocates a mask buffer with one byte for each of max_ele_num elements.
  Dropout(const Config &config, size_t max_ele_num)
      : _mask(cuda_malloc<uint8_t>(max_ele_num)), _config(config) {}

  virtual ~Dropout() { cuda_free(_mask); }

  // Plain dropout, used after the attention softmax.
  void dropout(T *output, const T *input, int count, cudaStream_t stream,
               bool bwd = false) {
    const float ratio = _config.RATIO();
    launch_ls_dropout<T>(output, input, _mask, count, ratio, stream, bwd);
  }

  // Backward of dropout(), applied in place on d_inp_out.
  void d_dropout(T *d_inp_out, int count, cudaStream_t stream) {
    const float ratio = _config.RATIO();
    launch_ls_dropout<T>(d_inp_out, d_inp_out, _mask, count, ratio, stream,
                         true);
  }

  // Transformer layer's post-processing dropout, after the attn or ffn module
  // and before the residual add.
  void bias_dropout_residual(T *output, const T *input, const T *residual,
                             const T *bias, int rows, int cols,
                             cudaStream_t stream) {
    const int total = rows * cols;
    const float ratio = _config.RATIO();
    launch_ls_dropout_res_bias<T>(output, input, _mask, bias, residual, total,
                                  cols, ratio, stream);
  }

  // Backward of bias_dropout_residual(): fills d_input and d_bias.
  void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output,
                               int rows, int cols, cudaStream_t stream) {
    const float ratio = _config.RATIO();
    launch_ls_dropout_bias_bwd<T>(d_input, d_bias, d_output, _mask, rows, cols,
                                  ratio, stream);
  }

  // Fused bias + activation + dropout inside the ffn. activation_fn must be
  // "relu" or "gelu"; anything else throws std::runtime_error.
  void bias_act_dropout(T *output, const T *input, const T *bias, int rows,
                        int cols, std::string activation_fn,
                        cudaStream_t stream) {
    const int total = rows * cols;
    const float ratio = _config.RATIO();
    if (activation_fn == "relu") {
      launch_ls_dropout_act_bias<ActivationType::kRelu, T>(
          output, input, _mask, bias, total, cols, ratio, stream);
      return;
    }
    if (activation_fn == "gelu") {
      launch_ls_dropout_act_bias<ActivationType::kGelu, T>(
          output, input, _mask, bias, total, cols, ratio, stream);
      return;
    }
    throw std::runtime_error("not supported activation: " + activation_fn);
  }

  // Backward of bias_act_dropout(). d_inp_out is passed both as the incoming
  // output gradient and as the destination of the input gradient.
  void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input,
                          const T *bias, int rows, int cols,
                          std::string activation_fn, cudaStream_t stream) {
    const float ratio = _config.RATIO();
    if (activation_fn == "relu") {
      launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, T>(
          d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
          ratio, stream);
      return;
    }
    if (activation_fn == "gelu") {
      launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, T>(
          d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
          ratio, stream);
      return;
    }
    throw std::runtime_error("not supported activation: " + activation_fn);
  }

  // True when the effective ratio is positive.
  bool HasDropout() const {
    const float ratio = _config.RATIO();
    return ratio > 0.0;
  }

  void SetTrainingMode(bool training) { _config.training = training; }

 private:
  uint8_t *_mask;  // mask buffer, allocated in the constructor
  Config _config;
};
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <string>
#include "kernels.h"
// Thin wrapper around the fused dropout launchers. It owns the byte mask that
// the forward launchers fill and the backward launchers read.
template <typename T>
class Dropout {
 public:
  // Dropout settings; the effective ratio is 0 when not training.
  struct Config {
    float ratio;
    bool training;

    Config(float r) : ratio(r), training(true) {}

    // Ratio actually applied: `ratio` while training, otherwise 0.
    float RATIO() const {
      if (training) return ratio;
      return 0.0;
    }
  };

  // Allocates a mask buffer with one byte for each of max_ele_num elements.
  Dropout(const Config &config, size_t max_ele_num)
      : _mask(cuda_malloc<uint8_t>(max_ele_num)), _config(config) {}

  virtual ~Dropout() { cuda_free(_mask); }

  // Plain dropout, used after the attention softmax.
  void dropout(T *output, const T *input, int count, cudaStream_t stream,
               bool bwd = false) {
    const float ratio = _config.RATIO();
    launch_ls_dropout<T>(output, input, _mask, count, ratio, stream, bwd);
  }

  // Backward of dropout(), applied in place on d_inp_out.
  void d_dropout(T *d_inp_out, int count, cudaStream_t stream) {
    const float ratio = _config.RATIO();
    launch_ls_dropout<T>(d_inp_out, d_inp_out, _mask, count, ratio, stream,
                         true);
  }

  // Transformer layer's post-processing dropout, after the attn or ffn module
  // and before the residual add.
  void bias_dropout_residual(T *output, const T *input, const T *residual,
                             const T *bias, int rows, int cols,
                             cudaStream_t stream) {
    const int total = rows * cols;
    const float ratio = _config.RATIO();
    launch_ls_dropout_res_bias<T>(output, input, _mask, bias, residual, total,
                                  cols, ratio, stream);
  }

  // Backward of bias_dropout_residual(): fills d_input and d_bias.
  void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output,
                               int rows, int cols, cudaStream_t stream) {
    const float ratio = _config.RATIO();
    launch_ls_dropout_bias_bwd<T>(d_input, d_bias, d_output, _mask, rows, cols,
                                  ratio, stream);
  }

  // Fused bias + activation + dropout inside the ffn. activation_fn must be
  // "relu" or "gelu"; anything else throws std::runtime_error.
  void bias_act_dropout(T *output, const T *input, const T *bias, int rows,
                        int cols, std::string activation_fn,
                        cudaStream_t stream) {
    const int total = rows * cols;
    const float ratio = _config.RATIO();
    if (activation_fn == "relu") {
      launch_ls_dropout_act_bias<ActivationType::kRelu, T>(
          output, input, _mask, bias, total, cols, ratio, stream);
      return;
    }
    if (activation_fn == "gelu") {
      launch_ls_dropout_act_bias<ActivationType::kGelu, T>(
          output, input, _mask, bias, total, cols, ratio, stream);
      return;
    }
    throw std::runtime_error("not supported activation: " + activation_fn);
  }

  // Backward of bias_act_dropout(). d_inp_out is passed both as the incoming
  // output gradient and as the destination of the input gradient.
  void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input,
                          const T *bias, int rows, int cols,
                          std::string activation_fn, cudaStream_t stream) {
    const float ratio = _config.RATIO();
    if (activation_fn == "relu") {
      launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, T>(
          d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
          ratio, stream);
      return;
    }
    if (activation_fn == "gelu") {
      launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, T>(
          d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
          ratio, stream);
      return;
    }
    throw std::runtime_error("not supported activation: " + activation_fn);
  }

  // True when the effective ratio is positive.
  bool HasDropout() const {
    const float ratio = _config.RATIO();
    return ratio > 0.0;
  }

  void SetTrainingMode(bool training) { _config.training = training; }

 private:
  uint8_t *_mask;  // mask buffer, allocated in the constructor
  Config _config;
};
......@@ -3,10 +3,11 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <stdexcept>
#include <stdio.h>
#include <stdlib.h>
#include <stdexcept>
#define MAX_THREADS 1024
#define WARP_SIZE 32
......@@ -132,8 +133,9 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
}
/* Convert 4-dim tensor index into vector index */
__forceinline__ __host__ __device__ int
flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) {
__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3,
int id4, int dim2, int dim3,
int dim4) {
// return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
int res = id4;
......@@ -201,9 +203,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
}
/* Convert vector index to 6-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5,
int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) {
__forceinline__ __host__ __device__ void decompose_6dim(
int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0,
int *id1, int *id2, int *id3, int *id4, int *id5) {
*id5 = src % dim5;
src /= dim5;
......@@ -221,9 +223,11 @@ decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5,
}
/* Convert vector index to 5-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0,
int *id1, int *id2, int *id3, int *id4) {
__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1,
int dim2, int dim3,
int dim4, int *id0,
int *id1, int *id2,
int *id3, int *id4) {
*id4 = src % dim4;
src /= dim4;
......@@ -253,8 +257,9 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
}
/* Convert vector index to 3-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) {
__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1,
int dim2, int *id0,
int *id1, int *id2) {
*id2 = src % dim2;
src /= dim2;
......
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "kernels.h"
using namespace std;
// Thin host-side wrapper around the layer-norm launchers.
//
// The object owns two buffers obtained from cuda_malloc<T>: `vars_` (always
// allocated) and `means_` (allocated only when Config::use_mean is set).
// Both are handed to cuda_free in the destructor.
// NOTE(review): copy operations are implicitly generated, so a copied
// instance would hand the same pointers to cuda_free twice -- confirm that
// callers never copy this type.
template <typename T> class Normalize_Layer {
public:
  // Static settings of one layer instance.
  struct Config {
    uint32_t hidden_dim; // passed to both launchers as the hidden size
    bool use_mean;       // when true, a `means_` buffer is allocated

    Config(uint32_t hidden_dim, bool use_mean = false)
        : hidden_dim(hidden_dim), use_mean(use_mean) {}
  };

  // Allocates `max_rows` elements for `vars_`, and the same amount for
  // `means_` when the configuration asks for it (otherwise it stays null).
  // Members are initialised in declaration order: config_, vars_, means_.
  Normalize_Layer(Config config, size_t max_rows)
      : config_(config), vars_(cuda_malloc<T>(max_rows)),
        means_(config.use_mean ? cuda_malloc<T>(max_rows) : nullptr) {}

  // Releases both buffers; `means_` may be null when use_mean is false.
  ~Normalize_Layer() {
    cuda_free(vars_);
    cuda_free(means_);
  }

  // Forward pass: delegates to launch_layer_norm, supplying the internal
  // `vars_` / `means_` buffers and the configured hidden size.
  void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta,
               int batch_size, cudaStream_t stream) {
    launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size,
                      config_.hidden_dim, stream);
  }

  /*
  Backward pass: delegates to launch_ln_bw.

  Handle residual_grad, inp_or_out and betta with care:
    - inp_or_out is the layer input when use_mean is true, otherwise the
      layer output.
    - residual_grad and betta may be nullptr.
    - A non-null residual_grad is added to dinp, which is useful for the
      pre-LN transformer layer.
    - betta is only used to compute xhat, so
      (use_mean == false) ^ (betta == nullptr) should be true.
  */
  void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
                const T *residual_grad, const T *inp_or_out, const T *gamma,
                const T *betta, int batch_size, cudaStream_t stream[2]) {
    launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad,
                 inp_or_out, gamma, betta, vars_, means_, batch_size,
                 config_.hidden_dim, stream);
  }

  // Reports whether this layer was configured with use_mean.
  bool use_mean() const { return config_.use_mean; }

private:
  Config config_;
  T *vars_;
  T *means_;
};
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "kernels.h"
using namespace std;
// Host-side helper that pairs the layer-norm launchers with the scratch
// buffers they need.
//
// `vars_` is always allocated through cuda_malloc<T>; `means_` only when
// Config::use_mean is set. The destructor passes both to cuda_free.
// NOTE(review): the class keeps its implicitly generated copy operations, so
// copying an instance would lead to cuda_free being called twice on the same
// pointers -- confirm that no caller copies it.
template <typename T>
class Normalize_Layer {
 public:
  // Per-layer settings.
  struct Config {
    uint32_t hidden_dim;  // forwarded to the launchers as the hidden size
    bool use_mean;        // request a `means_` buffer in addition to `vars_`

    Config(uint32_t hidden_dim, bool use_mean = false)
        : hidden_dim(hidden_dim), use_mean(use_mean) {}
  };

  // Reserves `max_rows` elements for `vars_` and, when use_mean is set, the
  // same amount for `means_`; otherwise `means_` remains null.
  Normalize_Layer(Config config, size_t max_rows)
      : config_(config), vars_(cuda_malloc<T>(max_rows)), means_(nullptr) {
    if (!config_.use_mean) {
      return;
    }
    means_ = cuda_malloc<T>(max_rows);
  }

  // Hands both buffers back via cuda_free (`means_` may be null).
  ~Normalize_Layer() {
    cuda_free(vars_);
    cuda_free(means_);
  }

  // Forward pass: calls launch_layer_norm with the internal buffers and the
  // configured hidden size.
  void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta,
               int batch_size, cudaStream_t stream) {
    launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size,
                      config_.hidden_dim, stream);
  }

  /*
  Backward pass: calls launch_ln_bw with the internal buffers.

  Take care with residual_grad, inp_or_out and betta:
    - inp_or_out = input if use_mean else output.
    - residual_grad and betta are allowed to be nullptr.
    - If residual_grad is not nullptr it is added to dinp, which is useful
      in a transformer layer when pre-LN is used.
    - betta is only used to compute xhat;
      (use_mean == false) ^ (betta == nullptr) should be true.
  */
  void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
                const T *residual_grad, const T *inp_or_out, const T *gamma,
                const T *betta, int batch_size, cudaStream_t stream[2]) {
    launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad,
                 inp_or_out, gamma, betta, vars_, means_, batch_size,
                 config_.hidden_dim, stream);
  }

  // Accessor for the use_mean setting.
  bool use_mean() const { return config_.use_mean; }

 private:
  Config config_;
  T *vars_;
  T *means_;
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment