Unverified commit 079bf3cb, authored by Hongxin Liu, committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -15,7 +15,7 @@ from .kvcache_manager import MemoryManager ...@@ -15,7 +15,7 @@ from .kvcache_manager import MemoryManager
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] _supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"]
class TPInferEngine: class TPInferEngine:
...@@ -39,14 +39,16 @@ class TPInferEngine: ...@@ -39,14 +39,16 @@ class TPInferEngine:
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs) >>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
""" """
def __init__(self, def __init__(
model: nn.Module, self,
shard_config: ShardConfig, model: nn.Module,
max_batch_size: int, shard_config: ShardConfig,
max_input_len: int, max_batch_size: int,
max_output_len: int, max_input_len: int,
dtype: torch.dtype = torch.float16, max_output_len: int,
device: str = 'cuda') -> None: dtype: torch.dtype = torch.float16,
device: str = "cuda",
) -> None:
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_input_len = max_input_len self.max_input_len = max_input_len
self.max_output_len = max_output_len self.max_output_len = max_output_len
...@@ -63,7 +65,7 @@ class TPInferEngine: ...@@ -63,7 +65,7 @@ class TPInferEngine:
self.head_num = model.config.num_attention_heads self.head_num = model.config.num_attention_heads
self.layer_num = model.config.num_hidden_layers self.layer_num = model.config.num_hidden_layers
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None self.cache_manager = None
self.shard_config = shard_config self.shard_config = shard_config
...@@ -74,9 +76,10 @@ class TPInferEngine: ...@@ -74,9 +76,10 @@ class TPInferEngine:
def _init_manager(self) -> None: def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads self.head_num //= self.tp_size # update sharded number of heads
self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.cache_manager = MemoryManager(
self.layer_num) self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
)
def _optimize_model(self, model: nn.Module) -> None: def _optimize_model(self, model: nn.Module) -> None:
""" """
...@@ -90,7 +93,7 @@ class TPInferEngine: ...@@ -90,7 +93,7 @@ class TPInferEngine:
self._shard_model_by(shardformer, model) self._shard_model_by(shardformer, model)
def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
""" Prepare the engine with a given ShardConfig. """Prepare the engine with a given ShardConfig.
Args: Args:
shard_config (ShardConfig): shard config given to specify settings of the engine. shard_config (ShardConfig): shard config given to specify settings of the engine.
...@@ -118,9 +121,10 @@ class TPInferEngine: ...@@ -118,9 +121,10 @@ class TPInferEngine:
return shard_config return shard_config
def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
""" Shard original model by the given ShardFormer and store the sharded model. """ """Shard original model by the given ShardFormer and store the sharded model."""
assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ assert (
"Discrepancy between the tp size of TPInferEngine and the tp size of shard config" self.tp_size == shardformer.shard_config.tensor_parallel_size
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__ model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True) policy = get_autopolicy(model, inference_only=True)
...@@ -147,7 +151,7 @@ class TPInferEngine: ...@@ -147,7 +151,7 @@ class TPInferEngine:
for t in input_tokens: for t in input_tokens:
if torch.is_tensor(input_tokens[t]): if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].cuda() input_tokens[t] = input_tokens[t].cuda()
if 'max_new_tokens' not in generate_kwargs: if "max_new_tokens" not in generate_kwargs:
generate_kwargs.update(max_new_tokens=self.max_output_len) generate_kwargs.update(max_new_tokens=self.max_output_len)
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)
...@@ -176,18 +180,18 @@ class TPInferEngine: ...@@ -176,18 +180,18 @@ class TPInferEngine:
attention_mask = None attention_mask = None
if isinstance(inputs, (BatchEncoding, dict)): if isinstance(inputs, (BatchEncoding, dict)):
input_ids_list = inputs['input_ids'] input_ids_list = inputs["input_ids"]
attention_mask = inputs['attention_mask'] attention_mask = inputs["attention_mask"]
else: else:
input_ids_list = inputs input_ids_list = inputs
if isinstance(input_ids_list[0], int): # for a single input if isinstance(input_ids_list[0], int): # for a single input
input_ids_list = [input_ids_list] input_ids_list = [input_ids_list]
attention_mask = [attention_mask] if attention_mask is not None else attention_mask attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list) batch_size = len(input_ids_list)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda') seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda') seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0 start_index = 0
max_len_in_batch = -1 max_len_in_batch = -1
...@@ -210,10 +214,10 @@ class TPInferEngine: ...@@ -210,10 +214,10 @@ class TPInferEngine:
seq_start_indexes[i] = start_index seq_start_indexes[i] = start_index
start_index += curr_seq_len start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda') block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to('cuda') batch_infer_state.seq_len = seq_lengths.to("cuda")
batch_infer_state.start_loc = seq_start_indexes.to('cuda') batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0 batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0 batch_infer_state.past_key_values_len = 0
...@@ -248,7 +252,7 @@ class TPInferEngine: ...@@ -248,7 +252,7 @@ class TPInferEngine:
model = self.model.model model = self.model.model
elif isinstance(model, BloomForCausalLM): elif isinstance(model, BloomForCausalLM):
model = self.model.transformer model = self.model.transformer
setattr(model, 'infer_state', batch_infer_state) setattr(model, "infer_state", batch_infer_state)
outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
...@@ -262,14 +266,15 @@ class TPInferEngine: ...@@ -262,14 +266,15 @@ class TPInferEngine:
# as an arg into model.forward. # as an arg into model.forward.
# It requires rewriting model generate and replacing model forward. # It requires rewriting model generate and replacing model forward.
@torch.no_grad() @torch.no_grad()
def _generate_by_pass_infer_state(self, def _generate_by_pass_infer_state(
input_tokens, self,
max_out_length: int, input_tokens,
generation_config: Optional[GenerationConfig] = None, max_out_length: int,
stopping_criteria: Optional[StoppingCriteriaList] = None, generation_config: Optional[GenerationConfig] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
**model_kwargs) -> torch.Tensor: prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
raise NotImplementedError("generate by passing BatchInferState is not implemented.") raise NotImplementedError("generate by passing BatchInferState is not implemented.")
# might want to use in rewritten generate method: use after model.forward # might want to use in rewritten generate method: use after model.forward
......
...@@ -19,13 +19,15 @@ class MemoryManager: ...@@ -19,13 +19,15 @@ class MemoryManager:
device: device used to store the key and value cache device: device used to store the key and value cache
""" """
def __init__(self, def __init__(
size: int, self,
dtype: torch.dtype, size: int,
head_num: int, dtype: torch.dtype,
head_dim: int, head_num: int,
layer_num: int, head_dim: int,
device: torch.device = torch.device('cuda')): layer_num: int,
device: torch.device = torch.device("cuda"),
):
self.logger = logging.get_logger(__name__) self.logger = logging.get_logger(__name__)
self.available_size = size self.available_size = size
self.past_key_values_length = 0 self.past_key_values_length = 0
...@@ -33,13 +35,13 @@ class MemoryManager: ...@@ -33,13 +35,13 @@ class MemoryManager:
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
def _init_mem_states(self, size, device): def _init_mem_states(self, size, device):
""" Initialize tensors used to manage memory states """ """Initialize tensors used to manage memory states"""
self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
self.indexes = torch.arange(0, size, dtype=torch.long, device=device) self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
""" Initialize key buffer and value buffer on specified device """ """Initialize key buffer and value buffer on specified device"""
self.key_buffer = [ self.key_buffer = [
torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
] ]
...@@ -49,10 +51,9 @@ class MemoryManager: ...@@ -49,10 +51,9 @@ class MemoryManager:
@torch.no_grad() @torch.no_grad()
def alloc(self, required_size): def alloc(self, required_size):
""" allocate space of required_size by providing indexes representing available physical spaces """ """allocate space of required_size by providing indexes representing available physical spaces"""
if required_size > self.available_size: if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} " self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
f"left_size {self.available_size}")
return None return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
...@@ -63,23 +64,25 @@ class MemoryManager: ...@@ -63,23 +64,25 @@ class MemoryManager:
@torch.no_grad() @torch.no_grad()
def alloc_contiguous(self, required_size): def alloc_contiguous(self, required_size):
""" allocate contiguous space of required_size """ """allocate contiguous space of required_size"""
if required_size > self.available_size: if required_size > self.available_size:
self.logger.warning(f"No enough cache: required_size {required_size} " self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
f"left_size {self.available_size}")
return None return None
torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
sum_size = len(self.mem_cum_sum) sum_size = len(self.mem_cum_sum)
loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + loc_sums = (
1] + self.mem_state[0:sum_size - self.mem_cum_sum[required_size - 1 :]
required_size + 1] - self.mem_cum_sum[0 : sum_size - required_size + 1]
can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] + self.mem_state[0 : sum_size - required_size + 1]
)
can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
if can_used_loc.shape[0] == 0: if can_used_loc.shape[0] == 0:
self.logger.info(f"No enough contiguous cache: required_size {required_size} " self.logger.info(
f"left_size {self.available_size}") f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}"
)
return None return None
start_loc = can_used_loc[0] start_loc = can_used_loc[0]
select_index = self.indexes[start_loc:start_loc + required_size] select_index = self.indexes[start_loc : start_loc + required_size]
self.mem_state[select_index] = 0 self.mem_state[select_index] = 0
self.available_size -= len(select_index) self.available_size -= len(select_index)
start = start_loc.item() start = start_loc.item()
...@@ -88,13 +91,13 @@ class MemoryManager: ...@@ -88,13 +91,13 @@ class MemoryManager:
@torch.no_grad() @torch.no_grad()
def free(self, free_index): def free(self, free_index):
""" free memory by updating memory states based on given indexes """ """free memory by updating memory states based on given indexes"""
self.available_size += free_index.shape[0] self.available_size += free_index.shape[0]
self.mem_state[free_index] = 1 self.mem_state[free_index] = 1
@torch.no_grad() @torch.no_grad()
def free_all(self): def free_all(self):
""" free all memory by updating memory states """ """free all memory by updating memory states"""
self.available_size = len(self.mem_state) self.available_size = len(self.mem_state)
self.mem_state[:] = 1 self.mem_state[:] = 1
self.past_key_values_length = 0 self.past_key_values_length = 0
......
from .bloom import BloomInferenceForwards from .bloom import BloomInferenceForwards
from .llama import LlamaInferenceForwards from .llama import LlamaInferenceForwards
__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards'] __all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"]
import math import math
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -31,17 +31,17 @@ def generate_alibi(n_head, dtype=torch.float16): ...@@ -31,17 +31,17 @@ def generate_alibi(n_head, dtype=torch.float16):
""" """
def get_slopes_power_of_2(n): def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3))) start = 2 ** (-(2 ** -(math.log2(n) - 3)))
return [start * start**i for i in range(n)] return [start * start**i for i in range(n)]
def get_slopes(n): def get_slopes(n):
if math.log2(n).is_integer(): if math.log2(n).is_integer():
return get_slopes_power_of_2(n) return get_slopes_power_of_2(n)
else: else:
closest_power_of_2 = 2**math.floor(math.log2(n)) closest_power_of_2 = 2 ** math.floor(math.log2(n))
slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
slopes_double = get_slopes(2 * closest_power_of_2) slopes_double = get_slopes(2 * closest_power_of_2)
slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2] slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
return slopes_combined return slopes_combined
slopes = get_slopes(n_head) slopes = get_slopes(n_head)
...@@ -72,7 +72,6 @@ class BloomInferenceForwards: ...@@ -72,7 +72,6 @@ class BloomInferenceForwards:
infer_state: Optional[BatchInferState] = None, infer_state: Optional[BatchInferState] = None,
**deprecated_arguments, **deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False: if deprecated_arguments.pop("position_ids", False) is not False:
...@@ -86,8 +85,9 @@ class BloomInferenceForwards: ...@@ -86,8 +85,9 @@ class BloomInferenceForwards:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states output_hidden_states = (
if output_hidden_states is not None else self.config.output_hidden_states) output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -122,14 +122,15 @@ class BloomInferenceForwards: ...@@ -122,14 +122,15 @@ class BloomInferenceForwards:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
logger.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False use_cache = False
# NOTE determine if BatchInferState is passed in via arg # NOTE determine if BatchInferState is passed in via arg
# if not, get the attr bound to the model # if not, get the attr bound to the model
# We might want to remove setattr later # We might want to remove setattr later
if infer_state is None: if infer_state is None:
assert hasattr(self, 'infer_state') assert hasattr(self, "infer_state")
infer_state = self.infer_state infer_state = self.infer_state
# Compute alibi tensor: check build_alibi_tensor documentation # Compute alibi tensor: check build_alibi_tensor documentation
...@@ -146,10 +147,11 @@ class BloomInferenceForwards: ...@@ -146,10 +147,11 @@ class BloomInferenceForwards:
if use_cache and seq_length != 1: if use_cache and seq_length != 1:
# prefill stage # prefill stage
infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, BatchInferState.init_block_loc(
infer_state.context_mem_index) infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else: else:
infer_state.is_context_stage = False infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
...@@ -182,8 +184,11 @@ class BloomInferenceForwards: ...@@ -182,8 +184,11 @@ class BloomInferenceForwards:
# alibi = generate_alibi(self.num_heads).contiguous().cuda() # alibi = generate_alibi(self.num_heads).contiguous().cuda()
tp_size = dist.get_world_size() tp_size = dist.get_world_size()
curr_tp_rank = dist.get_rank() curr_tp_rank = dist.get_rank()
alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) * alibi = (
self.num_heads].cuda() generate_alibi(self.num_heads * tp_size)
.contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
.cuda()
)
causal_mask = self._prepare_attn_mask( causal_mask = self._prepare_attn_mask(
attention_mask, attention_mask,
input_shape=(batch_size, seq_length), input_shape=(batch_size, seq_length),
...@@ -197,7 +202,6 @@ class BloomInferenceForwards: ...@@ -197,7 +202,6 @@ class BloomInferenceForwards:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# NOTE: currently our KV cache manager does not handle this condition # NOTE: currently our KV cache manager does not handle this condition
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
...@@ -250,32 +254,34 @@ class BloomInferenceForwards: ...@@ -250,32 +254,34 @@ class BloomInferenceForwards:
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=presents, # should always be (None, None, ..., None) past_key_values=presents, # should always be (None, None, ..., None)
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
) )
@staticmethod @staticmethod
def bloom_for_causal_lm_forward(self: BloomForCausalLM, def bloom_for_causal_lm_forward(
input_ids: Optional[torch.LongTensor] = None, self: BloomForCausalLM,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
head_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None, labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
infer_state: Optional[BatchInferState] = None, return_dict: Optional[bool] = None,
**deprecated_arguments): infer_state: Optional[BatchInferState] = None,
**deprecated_arguments,
):
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
""" """
logger = logging.get_logger(__name__) logging.get_logger(__name__)
if deprecated_arguments.pop("position_ids", False) is not False: if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
...@@ -289,17 +295,19 @@ class BloomInferenceForwards: ...@@ -289,17 +295,19 @@ class BloomInferenceForwards:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer, transformer_outputs = BloomInferenceForwards.bloom_model_forward(
input_ids, self.transformer,
past_key_values=past_key_values, input_ids,
attention_mask=attention_mask, past_key_values=past_key_values,
head_mask=head_mask, attention_mask=attention_mask,
inputs_embeds=inputs_embeds, head_mask=head_mask,
use_cache=use_cache, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, use_cache=use_cache,
output_hidden_states=output_hidden_states, output_attentions=output_attentions,
return_dict=return_dict, output_hidden_states=output_hidden_states,
infer_state=infer_state) return_dict=return_dict,
infer_state=infer_state,
)
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
...@@ -314,8 +322,9 @@ class BloomInferenceForwards: ...@@ -314,8 +322,9 @@ class BloomInferenceForwards:
batch_size, seq_length, vocab_size = shift_logits.shape batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens # Flatten the tokens
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), loss = loss_fct(
shift_labels.view(batch_size * seq_length)) shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict: if not return_dict:
output = (lm_logits,) + transformer_outputs[1:] output = (lm_logits,) + transformer_outputs[1:]
...@@ -353,11 +362,13 @@ class BloomInferenceForwards: ...@@ -353,11 +362,13 @@ class BloomInferenceForwards:
else: else:
model_inputs = {"input_ids": input_ids} model_inputs = {"input_ids": input_ids}
model_inputs.update({ model_inputs.update(
"past_key_values": past_key_values, {
"use_cache": kwargs.get("use_cache"), "past_key_values": past_key_values,
"attention_mask": attention_mask, "use_cache": kwargs.get("use_cache"),
}) "attention_mask": attention_mask,
}
)
return model_inputs return model_inputs
@staticmethod @staticmethod
...@@ -416,7 +427,7 @@ class BloomInferenceForwards: ...@@ -416,7 +427,7 @@ class BloomInferenceForwards:
else: else:
outputs = (output,) + outputs[1:] outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions return outputs # hidden_states, present, attentions
@staticmethod @staticmethod
def bloom_attention_forward( def bloom_attention_forward(
...@@ -431,20 +442,19 @@ class BloomInferenceForwards: ...@@ -431,20 +442,19 @@ class BloomInferenceForwards:
output_attentions: bool = False, output_attentions: bool = False,
infer_state: Optional[BatchInferState] = None, infer_state: Optional[BatchInferState] = None,
): ):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim] # 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, H, D_HEAD = query_layer.shape batch_size, q_length, H, D_HEAD = query_layer.shape
k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_length == 1 k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_length == 1
v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_length == 1 v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_length == 1
mem_manager = infer_state.cache_manager mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id layer_id = infer_state.decode_layer_id
if layer_id == 0: # once per model.forward if layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_length # += 1 infer_state.cache_manager.past_key_values_length += q_length # += 1
if infer_state.is_context_stage: if infer_state.is_context_stage:
# context process # context process
...@@ -471,9 +481,11 @@ class BloomInferenceForwards: ...@@ -471,9 +481,11 @@ class BloomInferenceForwards:
if infer_state.decode_is_contiguous: if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[layer_id][ cache_k = infer_state.cache_manager.key_buffer[layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[layer_id][ cache_v = infer_state.cache_manager.value_buffer[layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(k) cache_k.copy_(k)
cache_v.copy_(v) cache_v.copy_(v)
else: else:
...@@ -486,8 +498,17 @@ class BloomInferenceForwards: ...@@ -486,8 +498,17 @@ class BloomInferenceForwards:
b_loc = infer_state.block_loc b_loc = infer_state.block_loc
b_seq_len = infer_state.seq_len b_seq_len = infer_state.seq_len
output = torch.empty_like(q) output = torch.empty_like(q)
token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc, token_attention_fwd(
b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi) q,
mem_manager.key_buffer[layer_id],
mem_manager.value_buffer[layer_id],
output,
b_loc,
b_start_loc,
b_seq_len,
infer_state.cache_manager.past_key_values_length,
alibi,
)
context_layer = output.view(batch_size, q_length, H * D_HEAD) context_layer = output.view(batch_size, q_length, H * D_HEAD)
...@@ -504,8 +525,8 @@ class BloomInferenceForwards: ...@@ -504,8 +525,8 @@ class BloomInferenceForwards:
output_tensor = torch.zeros_like(context_layer) output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp): for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear( output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)], context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)], self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
) )
else: else:
output_tensor = self.dense(context_layer) output_tensor = self.dense(context_layer)
......
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np
import torch import torch
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
...@@ -15,6 +14,7 @@ from colossalai.kernel.triton import ( ...@@ -15,6 +14,7 @@ from colossalai.kernel.triton import (
try: try:
from vllm import layernorm_ops, pos_encoding_ops from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True HAS_VLLM_KERNERL = True
...@@ -29,17 +29,17 @@ except: ...@@ -29,17 +29,17 @@ except:
def rotate_half(x):
    """Rotate the two halves of the last dimension of ``x``.

    The last dimension is split at index ``x.shape[-1] // 2``. The result is
    the negated trailing part followed by the leading part, concatenated along
    the last dimension, so the output has the same shape as ``x``.

    Args:
        x: A tensor supporting ``...`` indexing, unary negation and
            ``torch.cat``.

    Returns:
        A tensor equal to ``cat((-x[..., half:], x[..., :half]), dim=-1)``.
    """
    # Compute the split point once instead of once per slice.
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids): def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them. # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin) q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin) k_embed = (k * cos) + (rotate_half(k) * sin)
...@@ -71,8 +71,7 @@ class LlamaInferenceForwards: ...@@ -71,8 +71,7 @@ class LlamaInferenceForwards:
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
): ):
batch_size = input_ids.shape[0] # input_ids.shape[0]
batch_size = input_ids.shape[0] # input_ids.shape[0]
infer_state = self.infer_state infer_state = self.infer_state
...@@ -103,10 +102,11 @@ class LlamaInferenceForwards: ...@@ -103,10 +102,11 @@ class LlamaInferenceForwards:
if use_cache and seq_length != 1: if use_cache and seq_length != 1:
# NOTE assume prefill stage # NOTE assume prefill stage
# allocate memory block # allocate memory block
infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.init_block_loc(
infer_state.context_mem_index) infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else: else:
infer_state.is_context_stage = False infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
...@@ -129,20 +129,20 @@ class LlamaInferenceForwards: ...@@ -129,20 +129,20 @@ class LlamaInferenceForwards:
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_key_values_length, position_ids = torch.arange(
seq_length + past_key_values_length, past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
dtype=torch.long, )
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else: else:
position_ids = position_ids.view(-1, seq_length).long() position_ids = position_ids.view(-1, seq_length).long()
if infer_state.is_context_stage: if infer_state.is_context_stage:
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1) position_ids.view(-1).shape[0], -1
)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1) position_ids.view(-1).shape[0], -1
)
else: else:
seq_len = infer_state.seq_len seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
...@@ -153,12 +153,13 @@ class LlamaInferenceForwards: ...@@ -153,12 +153,13 @@ class LlamaInferenceForwards:
# embed positions # embed positions
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), attention_mask = torch.ones(
dtype=torch.bool, (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
device=inputs_embeds.device) )
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, attention_mask = self._prepare_decoder_attention_mask(
past_key_values_length) attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -216,7 +217,6 @@ class LlamaInferenceForwards: ...@@ -216,7 +217,6 @@ class LlamaInferenceForwards:
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None, infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -261,7 +261,6 @@ class LlamaInferenceForwards: ...@@ -261,7 +261,6 @@ class LlamaInferenceForwards:
use_cache: bool = False, use_cache: bool = False,
infer_state: Optional[BatchInferState] = None, infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert use_cache is True, "use_cache should be set to True using this llama attention" assert use_cache is True, "use_cache should be set to True using this llama attention"
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
...@@ -277,8 +276,8 @@ class LlamaInferenceForwards: ...@@ -277,8 +276,8 @@ class LlamaInferenceForwards:
# NOTE might want to revise # NOTE might want to revise
# need some way to record the length of past key values cache # need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now # since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
...@@ -299,38 +298,62 @@ class LlamaInferenceForwards: ...@@ -299,38 +298,62 @@ class LlamaInferenceForwards:
# first token generation # first token generation
# copy key and value calculated in current step to memory manager # copy key and value calculated in current step to memory manager
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, _copy_kv_to_mem_cache(
infer_state.cache_manager) infer_state.decode_layer_id,
key_states,
value_states,
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states) attn_output = torch.empty_like(query_states)
llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc, llama_context_attn_fwd(
infer_state.seq_len, infer_state.cache_manager.past_key_values_length) query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
else: else:
if infer_state.decode_is_contiguous: if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(key_states) cache_k.copy_(key_states)
cache_v.copy_(value_states) cache_v.copy_(value_states)
else: else:
# if decode is not contiguous, use triton kernel to copy key and value cache # if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, _copy_kv_to_mem_cache(
infer_state.decode_mem_index, infer_state.cache_manager) infer_state.decode_layer_id,
key_states,
value_states,
infer_state.decode_mem_index,
infer_state.cache_manager,
)
# second token and follows # second token and follows
# kv = torch.stack((key_states, value_states), dim=2) # kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim) # (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states) attn_output = torch.empty_like(query_states)
token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], token_attention_fwd(
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, query_states,
infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.past_key_values_length) infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = attn_output.view(bsz, q_len, self.hidden_size)
...@@ -341,7 +364,6 @@ class LlamaInferenceForwards: ...@@ -341,7 +364,6 @@ class LlamaInferenceForwards:
def get_llama_vllm_rmsnorm_forward(): def get_llama_vllm_rmsnorm_forward():
if HAS_VLLM_KERNERL: if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
......
from .bloom import BloomModelInferPolicy from .bloom import BloomModelInferPolicy
from .llama import LlamaModelInferPolicy from .llama import LlamaModelInferPolicy
__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy'] __all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"]
...@@ -9,6 +9,7 @@ from ..modeling.bloom import BloomInferenceForwards ...@@ -9,6 +9,7 @@ from ..modeling.bloom import BloomInferenceForwards
try: try:
from colossalai.kernel.triton import layer_norm from colossalai.kernel.triton import layer_norm
HAS_TRITON_NORM = True HAS_TRITON_NORM = True
except: except:
print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton") print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
...@@ -27,40 +28,40 @@ def get_triton_layernorm_forward(): ...@@ -27,40 +28,40 @@ def get_triton_layernorm_forward():
class BloomModelInferPolicy(BloomForCausalLMPolicy):
    """Policy that swaps BLOOM module methods for their inference versions.

    Starts from the policy produced by ``BloomForCausalLMPolicy.module_policy``
    and registers method replacements taken from ``BloomInferenceForwards``.
    """

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        """Build the module policy with inference method replacements applied.

        Returns:
            The policy object returned by the parent ``module_policy`` after
            replacements have been registered for ``BloomForCausalLM``,
            ``BloomModel``, ``BloomBlock`` and ``BloomAttention`` (and for
            ``LayerNorm`` when ``HAS_TRITON_NORM`` is truthy).
        """
        from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel

        policy = super().module_policy()
        # NOTE set inference mode to shard config
        self.shard_config._infer()

        # (target class, methods to replace on it); registered in this order.
        replacements = [
            (
                BloomForCausalLM,
                {
                    "forward": BloomInferenceForwards.bloom_for_causal_lm_forward,
                    "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation,
                },
            ),
            (BloomModel, {"forward": BloomInferenceForwards.bloom_model_forward}),
            (BloomBlock, {"forward": BloomInferenceForwards.bloom_block_forward}),
            (BloomAttention, {"forward": BloomInferenceForwards.bloom_attention_forward}),
        ]
        for target_key, method_replacement in replacements:
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=target_key
            )

        # The triton layer norm kernel is optional; only swap LayerNorm.forward
        # when the module-level import of that kernel succeeded.
        if HAS_TRITON_NORM:
            infer_method = get_triton_layernorm_forward()
            method_replacement = {"forward": partial(infer_method)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=LayerNorm
            )

        return policy
...@@ -10,6 +10,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw ...@@ -10,6 +10,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forw
try: try:
from colossalai.kernel.triton import rmsnorm_forward from colossalai.kernel.triton import rmsnorm_forward
HAS_TRITON_RMSNORM = True HAS_TRITON_RMSNORM = True
except: except:
print("you should install triton from https://github.com/openai/triton") print("you should install triton from https://github.com/openai/triton")
...@@ -28,7 +29,6 @@ def get_triton_rmsnorm_forward(): ...@@ -28,7 +29,6 @@ def get_triton_rmsnorm_forward():
class LlamaModelInferPolicy(LlamaForCausalLMPolicy): class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -37,20 +37,20 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): ...@@ -37,20 +37,20 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
self.shard_config._infer() self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)} method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
method_replacement = {'forward': partial(infer_forward)} method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, self.append_or_create_method_replacement(
policy=policy, description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
target_key=LlamaDecoderLayer) )
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
method_replacement = {'forward': partial(infer_forward)} method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, self.append_or_create_method_replacement(
policy=policy, description=method_replacement, policy=policy, target_key=LlamaAttention
target_key=LlamaAttention) )
infer_forward = None infer_forward = None
if HAS_TRITON_RMSNORM: if HAS_TRITON_RMSNORM:
...@@ -60,9 +60,9 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): ...@@ -60,9 +60,9 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
infer_forward = get_llama_vllm_rmsnorm_forward() infer_forward = get_llama_vllm_rmsnorm_forward()
if infer_forward is not None: if infer_forward is not None:
method_replacement = {'forward': partial(infer_forward)} method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, self.append_or_create_method_replacement(
policy=policy, description=method_replacement, policy=policy, target_key=LlamaRMSNorm
target_key=LlamaRMSNorm) )
return policy return policy
...@@ -14,15 +14,17 @@ from colossalai.logging import get_dist_logger ...@@ -14,15 +14,17 @@ from colossalai.logging import get_dist_logger
from colossalai.utils import set_device, set_seed from colossalai.utils import set_device, set_seed
def launch(config: Union[str, Path, Config, Dict], def launch(
rank: int, config: Union[str, Path, Config, Dict],
world_size: int, rank: int,
host: str, world_size: int,
port: int, host: str,
backend: str = 'nccl', port: int,
local_rank: int = None, backend: str = "nccl",
seed: int = 1024, local_rank: int = None,
verbose: bool = True): seed: int = 1024,
verbose: bool = True,
):
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
arguments are not given. Then initialize and set distributed environment by calling global_context's functions. arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
...@@ -46,7 +48,7 @@ def launch(config: Union[str, Path, Config, Dict], ...@@ -46,7 +48,7 @@ def launch(config: Union[str, Path, Config, Dict],
warnings.warn("`config` is deprecated and will be removed soon.") warnings.warn("`config` is deprecated and will be removed soon.")
# init default process group # init default process group
init_method = f'tcp://[{host}]:{port}' init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device # set cuda device
...@@ -58,15 +60,17 @@ def launch(config: Union[str, Path, Config, Dict], ...@@ -58,15 +60,17 @@ def launch(config: Union[str, Path, Config, Dict],
if verbose: if verbose:
logger = get_dist_logger() logger = get_dist_logger()
logger.info(f'Distributed environment is initialized, world size: {dist.get_world_size()}', ranks=[0]) logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
def launch_from_slurm(config: Union[str, Path, Config, Dict], def launch_from_slurm(
host: str, config: Union[str, Path, Config, Dict],
port: int, host: str,
backend: str = 'nccl', port: int,
seed: int = 1024, backend: str = "nccl",
verbose: bool = True): seed: int = 1024,
verbose: bool = True,
):
"""A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
set by SLURM set by SLURM
...@@ -79,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict], ...@@ -79,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True. verbose (bool, optional): Whether to print logs. Defaults to True.
""" """
try: try:
rank = int(os.environ['SLURM_PROCID']) rank = int(os.environ["SLURM_PROCID"])
world_size = int(os.environ['SLURM_NPROCS']) world_size = int(os.environ["SLURM_NPROCS"])
except KeyError as e: except KeyError as e:
raise RuntimeError( raise RuntimeError(
f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
) )
launch(config=config, launch(
rank=rank, config=config,
world_size=world_size, rank=rank,
host=host, world_size=world_size,
port=port, host=host,
backend=backend, port=port,
seed=seed, backend=backend,
verbose=verbose) seed=seed,
verbose=verbose,
)
def launch_from_openmpi(config: Union[str, Path, Config, Dict],
host: str,
port: int, def launch_from_openmpi(
backend: str = 'nccl', config: Union[str, Path, Config, Dict],
seed: int = 1024, host: str,
verbose: bool = True): port: int,
backend: str = "nccl",
seed: int = 1024,
verbose: bool = True,
):
"""A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
set by OpenMPI set by OpenMPI
...@@ -114,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict], ...@@ -114,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True. verbose (bool, optional): Whether to print logs. Defaults to True.
""" """
try: try:
rank = int(os.environ['OMPI_COMM_WORLD_RANK']) rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
except KeyError as e: except KeyError as e:
raise RuntimeError( raise RuntimeError(
f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
) )
launch(config=config, launch(
local_rank=local_rank, config=config,
rank=rank, local_rank=local_rank,
world_size=world_size, rank=rank,
host=host, world_size=world_size,
port=port, host=host,
backend=backend, port=port,
seed=seed, backend=backend,
verbose=verbose) seed=seed,
verbose=verbose,
)
def launch_from_torch(config: Union[str, Path, Config, Dict],
backend: str = 'nccl',
seed: int = 1024, def launch_from_torch(
verbose: bool = True): config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
):
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch from the environment variables set by PyTorch
...@@ -147,22 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], ...@@ -147,22 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True. verbose (bool, optional): Whether to print logs. Defaults to True.
""" """
try: try:
rank = int(os.environ['RANK']) rank = int(os.environ["RANK"])
local_rank = int(os.environ['LOCAL_RANK']) local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ['WORLD_SIZE']) world_size = int(os.environ["WORLD_SIZE"])
host = os.environ['MASTER_ADDR'] host = os.environ["MASTER_ADDR"]
port = int(os.environ['MASTER_PORT']) port = int(os.environ["MASTER_PORT"])
except KeyError as e: except KeyError as e:
raise RuntimeError( raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
) )
launch(config=config, launch(
local_rank=local_rank, config=config,
rank=rank, local_rank=local_rank,
world_size=world_size, rank=rank,
host=host, world_size=world_size,
port=port, host=host,
backend=backend, port=port,
seed=seed, backend=backend,
verbose=verbose) seed=seed,
verbose=verbose,
)
from .model import AMPModelMixin, ModelWrapper from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper from .optimizer import OptimizerWrapper
__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin'] __all__ = ["OptimizerWrapper", "ModelWrapper", "AMPModelMixin"]
...@@ -26,11 +26,9 @@ class ModelWrapper(nn.Module): ...@@ -26,11 +26,9 @@ class ModelWrapper(nn.Module):
class AMPModelMixin:
    """This mixin class defines the interface for AMP training."""

    def update_master_params(self):
        """
        Update the master parameters for AMP training.

        The implementation here does nothing and returns ``None``; classes
        mixing this in are presumably expected to override it (not shown in
        this file).
        """
...@@ -22,7 +22,7 @@ class OptimizerWrapper: ...@@ -22,7 +22,7 @@ class OptimizerWrapper:
params = [] params = []
for group in self.param_groups: for group in self.param_groups:
params += group['params'] params += group["params"]
return params return params
@property @property
...@@ -82,12 +82,14 @@ class OptimizerWrapper: ...@@ -82,12 +82,14 @@ class OptimizerWrapper:
""" """
nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs) nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)
def clip_grad_by_norm(self, def clip_grad_by_norm(
max_norm: Union[float, int], self,
norm_type: Union[float, int] = 2.0, max_norm: Union[float, int],
error_if_nonfinite: bool = False, norm_type: Union[float, int] = 2.0,
*args, error_if_nonfinite: bool = False,
**kwargs) -> Tensor: *args,
**kwargs,
) -> Tensor:
""" """
Clips gradient norm of an iterable of parameters. Clips gradient norm of an iterable of parameters.
...@@ -113,7 +115,8 @@ class OptimizerWrapper: ...@@ -113,7 +115,8 @@ class OptimizerWrapper:
loss (Tensor): The loss to be scaled. loss (Tensor): The loss to be scaled.
""" """
raise NotImplementedError( raise NotImplementedError(
"The method scale_loss is only available for optimizers with mixed precision training") "The method scale_loss is only available for optimizers with mixed precision training"
)
def unscale_grad(self): def unscale_grad(self):
""" """
...@@ -122,7 +125,8 @@ class OptimizerWrapper: ...@@ -122,7 +125,8 @@ class OptimizerWrapper:
Note: Only available for optimizers with mixed precision training. Note: Only available for optimizers with mixed precision training.
""" """
raise NotImplementedError( raise NotImplementedError(
"The method unscale_grad is only available for optimizers with mixed precision training") "The method unscale_grad is only available for optimizers with mixed precision training"
)
def unwrap(self): def unwrap(self):
""" """
......
...@@ -4,6 +4,10 @@ from .multihead_attention import MultiHeadAttention ...@@ -4,6 +4,10 @@ from .multihead_attention import MultiHeadAttention
from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
__all__ = [ __all__ = [
'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention', "LayerNorm",
'AttnMaskType' "MultiHeadAttention",
"FusedScaleMaskSoftmax",
"ScaledUpperTriangMaskedSoftmax",
"ColoAttention",
"AttnMaskType",
] ]
...@@ -7,4 +7,4 @@ ...@@ -7,4 +7,4 @@
#define DATA_PTR data_ptr #define DATA_PTR data_ptr
#else #else
#define DATA_PTR data #define DATA_PTR data
#endif #endif
\ No newline at end of file
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/reduce.h> #include <thrust/reduce.h>
#include "cuda_util.h" #include "cuda_util.h"
/* GPU function guard */ /* GPU function guard */
......
#include <chrono> #include <chrono>
#include <ctime> #include <ctime>
#include "kernels.h" #include "kernels.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
curandStatePhilox4_32_10_t *curandstate; curandStatePhilox4_32_10_t *curandstate;
/** /**
* @brief element-wise activation function on device, like Relu, Gelu * @brief element-wise activation function on device, like Relu, Gelu
* *
* @tparam enum class ActivationType, kRelu, kGelu * @tparam enum class ActivationType, kRelu, kGelu
* @tparam input type * @tparam input type
* @param any shape of float and __half2 * @param any shape of float and __half2
* @return same shape and type with input * @return same shape and type with input
*/ */
template <ActivationType, typename T> template <ActivationType, typename T>
__forceinline__ __device__ T activation_kernel(T x); __forceinline__ __device__ T activation_kernel(T x);
/**
 * GELU for a single float, using the tanh approximation
 *   0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3))),  c = 0.7978845608028654
 * (c is sqrt(2 / pi)).
 */
template <>
__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
  float cdf =
      0.5f *
      (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}
// GELU (tanh approximation) for a packed pair of halves.
// x^3 and the final product are computed in half precision; the tanh term is
// evaluated per lane in float and rounded back to half.
template <>
__device__ __half2
activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
  const __half2 cubed_h2 = __hmul2(val, __hmul2(val, val));
  const float2 cubed = __half22float2(cubed_h2);
  float2 cdf = __half22float2(val);  // holds x on entry, the cdf term on exit
  cdf.x = 0.5f *
          (1.0f + tanhf((0.7978845608028654f * (cdf.x + 0.044715f * cubed.x))));
  cdf.y = 0.5f *
          (1.0f + tanhf((0.7978845608028654f * (cdf.y + 0.044715f * cubed.y))));
  return __hmul2(val, __float22half2_rn(cdf));
}
// ReLU for a single float: max(x, 0).
template <>
__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
  const float zero = 0;
  return fmaxf(x, zero);
}
// ReLU for a packed pair of halves: each lane is widened to float, clamped at
// zero and rounded back to half.
template <>
__device__ __half2
activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
  const float2 lanes = __half22float2(x);
  return __floats2half2_rn(fmaxf(0.f, lanes.x), fmaxf(0.f, lanes.y));
}
/**
 * @brief Primary template of the element-wise device activation backward.
 *
 * Only declared here; the behaviour comes from the explicit specializations
 * (kGelu for float/__half, kRelu for float/__half/__half2).
 *
 * @tparam act_type  ActivationType selecting the activation
 * @tparam T         element type
 * @param grad       gradient flowing in from the activation output
 * @param x          value the activation was applied to in the forward pass
 * @return gradient with respect to x, same type as the input
 */
template <ActivationType act_type, typename T>
__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
// Backward of tanh-approximated GELU for a single float.
// With u = a * (x + b * x^3), a = sqrt(2/pi), b = 0.044715:
//   d/dx = 0.5 * (1 + tanh u)                       (cdf term)
//        + 0.5 * x * a * (1 - tanh^2 u)             (linear part of du/dx)
//        + 0.5 * x * a * (1 - tanh^2 u) * 3 b x^2   (cubic part of du/dx)
template <>
__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
                                                                     float x) {
  const float kAlpha = 0.79788456080286535587989211986876f;
  const float kBeta = 0.044715;
  const float beta_x2 = x * x * kBeta;
  const float t = tanhf(kAlpha * (x + x * beta_x2));
  const float cdf_term = 0.5f * (1.0f + t);
  const float linear_term = x * 0.5f * kAlpha * (1 - t * t);
  const float cubic_term = linear_term * 3 * beta_x2;
  return grad * (cdf_term + linear_term + cubic_term);
}
// Backward of tanh-approximated GELU for a single half.
// The derivative is evaluated in float (same formula as the float
// specialization), rounded to half, then multiplied by grad in half precision.
template <>
__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
    __half grad, __half x_half) {
  const float kAlpha = 0.79788456080286535587989211986876f;
  const float kBeta = 0.044715;
  const float x = __half2float(x_half);
  const float beta_x2 = x * x * kBeta;
  const float t = tanhf(kAlpha * (x + x * beta_x2));
  const float cdf_term = 0.5f * (1.0f + t);
  const float linear_term = x * 0.5f * kAlpha * (1 - t * t);
  const float cubic_term = linear_term * 3 * beta_x2;
  return grad * __float2half(cdf_term + linear_term + cubic_term);
}
// Backward of ReLU for a single float: pass grad through where x > 0,
// otherwise produce 0.
template <>
__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
                                                                     float x) {
  if (x > 0.f) {
    return grad;
  }
  return 0.f;
}
// Backward of ReLU for a single half: pass grad through where x > 0,
// otherwise produce 0.
template <>
__device__ __half
activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
  const __half half_zero = __float2half(0.f);
  if (x > half_zero) {
    return grad;
  }
  return half_zero;
}
// Backward of ReLU for a packed pair of halves: each lane passes its grad
// through where the matching x lane is > 0, otherwise produces 0.
//
// The selected lanes are already __half, so they are packed directly with
// __halves2half2. The previous form passed them to __floats2half2_rn, which
// widened each lane to float and rounded it back (and needed the implicit
// __half -> float conversion to be enabled).
template <>
__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
    __half2 grad2, __half2 x_half2) {
  const __half half_zero = __float2half(0.f);
  const __half lo = x_half2.x > half_zero ? grad2.x : half_zero;
  const __half hi = x_half2.y > half_zero ? grad2.y : half_zero;
  return __halves2half2(lo, hi);
}
/**
 * @brief Initialise one persistent curand state per thread in global memory.
 *
 * Every thread uses the same seed with its global thread index as the
 * sequence number and an offset of 0.
 *
 * @thread any gridDim.x * blockDim.x; there is no bounds check, so `state`
 *         must hold at least gridDim.x * blockDim.x entries
 * @param state persistent curand states
 * @param seed  seed used to initialise every state
 * @return void
 */
__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
                                   int seed) {
  const int state_idx = blockIdx.x * blockDim.x + threadIdx.x;
  curand_init(seed, state_idx, 0, state + state_idx);
}
/**
 * @brief Allocate the global `curandstate` buffer and initialise it on `stream`.
 *
 * The kernel runs with 512 threads per block and has no bounds check, so the
 * block count is rounded up and the buffer is padded to a whole number of
 * blocks. Previously the block count was `total_count >> 9`, which launched an
 * empty grid for total_count < 512 and left the tail states uninitialised
 * whenever total_count was not a multiple of 512.
 *
 * @param total_count number of states requested; nothing is done if <= 0
 * @param dim         unused
 * @param stream      stream the initialisation kernel is launched on
 *
 * NOTE(review): a previous `curandstate` allocation is not freed here; confirm
 * this is only called once per process.
 */
void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
  (void)dim;
  if (total_count <= 0) return;

  const int kThreads = 512;
  // Ceil-divide without risking overflow of total_count + kThreads - 1.
  const int grid_dim = total_count / kThreads + (total_count % kThreads != 0);
  const size_t padded_count = static_cast<size_t>(grid_dim) * kThreads;

  // Do not launch on a buffer that was never allocated.
  if (cudaMalloc(&curandstate,
                 padded_count * sizeof(curandStatePhilox4_32_10_t)) !=
      cudaSuccess) {
    return;
  }

  // Seed from the wall clock (microseconds since epoch, narrowed to int).
  const int seed =
      static_cast<int>(std::chrono::duration_cast<std::chrono::microseconds>(
                           std::chrono::system_clock::now().time_since_epoch())
                           .count());
  curand_init_kernel<<<grid_dim, kThreads, 0, stream>>>(curandstate, seed);
}
/**
 * @brief Element-wise dropout (not in-place) that records the keep mask.
 *
 * Each thread handles 4 consecutive floats through one float4 load/store and
 * writes the 4 matching mask bytes as one uint32_t. A mask byte is 1 where the
 * element is kept (random draw > ratio) and 0 where it is dropped; kept
 * elements are scaled by 1 / (1 - ratio).
 *
 * @thread
 * one thread per 4 elements; launch_ls_dropout<float> uses
 * (total_count >> 12) + 1 blocks of 1024 threads
 *
 * @param total_count total elements; a thread whose first element is in range
 *                    processes all 4, so this looks like it must be a
 *                    multiple of 4 (TODO confirm with callers)
 * @param ratio drop ratio
 * @param out   output, same size as in
 * @param in    input
 * @param mask  uint8 keep mask, same size as out
 * @param seed  seed for curand; the thread index is the sequence number
 * @return void
 */
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  float *__restrict__ out,
                                  const float *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  const float keep_scale = 1.f / (1.f - ratio);
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 4 >= total_count) return;

  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_idx, 0, &rng);
  const float4 draw = curand_uniform4(&rng);

  uint8_t keep[4];
  keep[0] = (uint8_t)(draw.x > ratio);
  keep[1] = (uint8_t)(draw.y > ratio);
  keep[2] = (uint8_t)(draw.z > ratio);
  keep[3] = (uint8_t)(draw.w > ratio);

  // Store the 4 mask bytes with a single 32-bit write.
  reinterpret_cast<uint32_t *>(mask)[vec_idx] =
      reinterpret_cast<uint32_t *>(keep)[0];

  const float4 src = reinterpret_cast<const float4 *>(in)[vec_idx];
  float4 dst;
  dst.x = src.x * keep_scale * keep[0];
  dst.y = src.y * keep_scale * keep[1];
  dst.z = src.z * keep_scale * keep[2];
  dst.w = src.w * keep_scale * keep[3];
  reinterpret_cast<float4 *>(out)[vec_idx] = dst;
}
// Half-precision overload of the element-wise dropout forward.
// Each thread handles 8 consecutive halves (one float4-sized load/store) and
// writes the 8 matching mask bytes as one uint64_t. A mask byte is 1 where the
// element is kept (random draw > ratio); kept elements are scaled by
// 1 / (1 - ratio), with the scale-times-mask factor rounded to half before
// the multiply.
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  __half *__restrict__ out,
                                  const __half *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  const float keep_scale = 1.f / (1.f - ratio);
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 8 >= total_count) return;

  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_idx, 0, &rng);

  // Two draws of 4 uniforms give the 8 keep decisions.
  uint8_t keep[8];
  float4 draw = curand_uniform4(&rng);
  keep[0] = (uint8_t)(draw.x > ratio);
  keep[1] = (uint8_t)(draw.y > ratio);
  keep[2] = (uint8_t)(draw.z > ratio);
  keep[3] = (uint8_t)(draw.w > ratio);
  draw = curand_uniform4(&rng);
  keep[4] = (uint8_t)(draw.x > ratio);
  keep[5] = (uint8_t)(draw.y > ratio);
  keep[6] = (uint8_t)(draw.z > ratio);
  keep[7] = (uint8_t)(draw.w > ratio);

  // Store the 8 mask bytes with a single 64-bit write.
  reinterpret_cast<uint64_t *>(mask)[vec_idx] =
      *reinterpret_cast<uint64_t *>(keep);

  // Load 8 halves at once and view them as 4 __half2 pairs.
  float4 packed_in = reinterpret_cast<const float4 *>(in)[vec_idx];
  float4 packed_out;
  __half2 *src = reinterpret_cast<__half2 *>(&packed_in);
  __half2 *dst = reinterpret_cast<__half2 *>(&packed_out);
#pragma unroll
  for (int p = 0; p < 4; ++p) {
    const __half2 scale_mask = __floats2half2_rn(keep_scale * keep[2 * p],
                                                 keep_scale * keep[2 * p + 1]);
    dst[p] = __hmul2(src[p], scale_mask);
  }
  reinterpret_cast<float4 *>(out)[vec_idx] = packed_out;
}
/**
 * @brief Element-wise dropout backward using a stored mask (not in-place).
 *
 * Each thread handles 4 consecutive floats: out = in * mask / (1 - ratio).
 *
 * @thread
 * one thread per 4 elements; launch_ls_dropout<float> uses
 * (total_count >> 12) + 1 blocks of 1024 threads
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out   output, same size as in
 * @param in    input
 * @param mask  uint8 mask, same size as in
 * @return void
 */
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      float *out, const float *in,
                                      const uint8_t *__restrict__ mask) {
  const float keep_scale = 1.f / (1.f - ratio);
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 4 >= total_count) return;

  // Fetch the 4 mask bytes with a single 32-bit read.
  uint8_t keep[4];
  reinterpret_cast<uint32_t *>(keep)[0] =
      reinterpret_cast<const uint32_t *>(mask)[vec_idx];

  const float4 src = reinterpret_cast<const float4 *>(in)[vec_idx];
  float4 dst;
  dst.x = src.x * keep_scale * static_cast<float>(keep[0]);
  dst.y = src.y * keep_scale * static_cast<float>(keep[1]);
  dst.z = src.z * keep_scale * static_cast<float>(keep[2]);
  dst.w = src.w * keep_scale * static_cast<float>(keep[3]);
  reinterpret_cast<float4 *>(out)[vec_idx] = dst;
}
// Half-precision overload of the element-wise dropout backward.
// Each thread handles 8 consecutive halves (one float4-sized load/store) and
// reads the 8 matching mask bytes as one uint64_t. The scale 1 / (1 - ratio)
// is rounded to half and multiplied by the mask in half precision.
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      __half *out, const __half *in,
                                      const uint8_t *__restrict__ mask) {
  const __half scale = 1.f / (1.f - ratio);
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 8 >= total_count) return;

  // Fetch the 8 mask bytes with a single 64-bit read.
  uint8_t keep[8];
  reinterpret_cast<uint64_t *>(keep)[0] =
      reinterpret_cast<const uint64_t *>(mask)[vec_idx];

  // Load 8 halves at once and view them as 4 __half2 pairs.
  float4 packed_in = reinterpret_cast<const float4 *>(in)[vec_idx];
  float4 packed_out;
  __half2 *src = reinterpret_cast<__half2 *>(&packed_in);
  __half2 *dst = reinterpret_cast<__half2 *>(&packed_out);
#pragma unroll
  for (int p = 0; p < 4; ++p) {
    const __half2 scale_mask =
        __halves2half2(scale * __float2half(keep[2 * p]),
                       scale * __float2half(keep[2 * p + 1]));
    dst[p] = __hmul2(src[p], scale_mask);
  }
  reinterpret_cast<float4 *>(out)[vec_idx] = packed_out;
}
// Launch float dropout forward (backward == false) or backward on `stream`.
// A block of 1024 threads covers 4096 floats (4 per thread); one extra block
// is always added to cover any remainder. The forward seed is the wall clock
// in microseconds since epoch, narrowed to the kernel's int parameter.
template <>
void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
                              int total_count, float ratio, cudaStream_t stream,
                              bool backward) {
  const int num_blocks = (total_count >> 12) + 1;
  if (backward) {
    ls_dropout_bwd_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio,
                                                           out, vals, mask);
    return;
  }
  const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count();
  ls_dropout_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio, out,
                                                     vals, mask, now_us);
}
// Launch half dropout forward (backward == false) or backward on `stream`.
// A block of 1024 threads covers 8192 halves (8 per thread); one extra block
// is always added to cover any remainder. The forward seed is the wall clock
// in microseconds since epoch, narrowed to the kernel's int parameter.
template <>
void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
                               int total_count, float ratio,
                               cudaStream_t stream, bool backward) {
  const int num_blocks = (total_count >> 13) + 1;
  if (backward) {
    ls_dropout_bwd_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio,
                                                           out, vals, mask);
    return;
  }
  const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count();
  ls_dropout_kernel<<<num_blocks, 1024, 0, stream>>>(total_count, ratio, out,
                                                     vals, mask, now_us);
}
/**
 * @brief Fused bias + dropout + residual (not in-place), recording the mask:
 *        out = (in + bias) * mask / (1 - ratio) + residual
 *
 * Each thread handles 4 consecutive floats through float4 loads/stores and
 * writes the 4 matching mask bytes as one uint32_t (1 = kept, 0 = dropped).
 * The bias is indexed modulo hidden_size / 4 float4 groups, so it repeats
 * along the innermost dimension.
 *
 * @thread
 * one thread per 4 elements; launch_ls_dropout_res_bias<float> uses
 * (total_count >> 12) + 1 blocks of 1024 threads
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out [batch_size, seq_len, hidden_size]
 * @param in [batch_size, seq_len, hidden_size]
 * @param mask [batch_size, seq_len, hidden_size], uint8 type
 * @param bias [hidden_size]
 * @param residual [batch_size, seq_len, hidden_size]
 * @param seed seed for curand; the thread index is the sequence number
 * @param hidden_size hidden size; indexing uses hidden_size >> 2, so this
 *                    looks like it must be a multiple of 4 (TODO confirm)
 * @return void
 */
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const float *__restrict__ residual,
    const int seed, const int hidden_size) {
  const float keep_scale = 1.f / (1.f - ratio);
  const int vec_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (vec_idx * 4 >= total_count) return;

  curandStatePhilox4_32_10_t rng;
  curand_init(seed, vec_idx, 0, &rng);
  const float4 draw = curand_uniform4(&rng);

  uint8_t keep[4];
  keep[0] = static_cast<uint8_t>(draw.x > ratio);
  keep[1] = static_cast<uint8_t>(draw.y > ratio);
  keep[2] = static_cast<uint8_t>(draw.z > ratio);
  keep[3] = static_cast<uint8_t>(draw.w > ratio);

  // Position of this thread's float4 group within the bias vector.
  const int bias_idx = vec_idx % (hidden_size >> 2);

  // Store the 4 mask bytes with a single 32-bit write.
  reinterpret_cast<uint32_t *>(mask)[vec_idx] =
      reinterpret_cast<uint32_t *>(keep)[0];

  const float4 src = reinterpret_cast<const float4 *>(in)[vec_idx];
  const float4 b = __ldg(&reinterpret_cast<const float4 *>(bias)[bias_idx]);
  const float4 res = reinterpret_cast<const float4 *>(residual)[vec_idx];

  float4 dst;
  dst.x = (src.x + b.x) * keep_scale * keep[0] + res.x;
  dst.y = (src.y + b.y) * keep_scale * keep[1] + res.y;
  dst.z = (src.z + b.z) * keep_scale * keep[2] + res.z;
  dst.w = (src.w + b.w) * keep_scale * keep[3] + res.w;
  reinterpret_cast<float4 *>(out)[vec_idx] = dst;
}
// Half-precision overload of fused bias + dropout + residual (not in-place):
//   out = (in + bias) * mask / (1 - ratio) + residual
//
// Each thread handles 8 consecutive halves (one float4-sized load/store per
// tensor) and writes the 8 matching mask bytes as one uint64_t (1 = kept,
// 0 = dropped). The bias is indexed modulo hidden_size / 8 groups, so it
// repeats along the innermost dimension; hidden_size therefore looks like it
// must be a multiple of 8 (TODO confirm with callers).
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const __half *__restrict__ residual,
    const int seed, const int hidden_size) {
  // Computed in single precision like the other dropout kernels in this file;
  // the previous `1. / (1. - ratio)` did the division in double.
  const __half scale = 1.f / (1.f - ratio);

  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  // Two draws of 4 uniforms give the 8 keep decisions.
  uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = static_cast<uint8_t>(rand.x > ratio);
  m[1] = static_cast<uint8_t>(rand.y > ratio);
  m[2] = static_cast<uint8_t>(rand.z > ratio);
  m[3] = static_cast<uint8_t>(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = static_cast<uint8_t>(rand.x > ratio);
  m[5] = static_cast<uint8_t>(rand.y > ratio);
  m[6] = static_cast<uint8_t>(rand.z > ratio);
  m[7] = static_cast<uint8_t>(rand.w > ratio);

  // Store the 8 mask bytes with a single 64-bit write.
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = m8[0];

  // Position of this thread's 8-half group within the bias vector.
  int bias_i = i % (hidden_size >> 3);

  float4 val_float4 = vals_float4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  const float4 res4 = residual4[i];
  float4 out_float4;

  // View each 16-byte register group as 4 __half2 pairs.
  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
  const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);

  __half2 scale_mask_1 =
      __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
  __half2 scale_mask_2 =
      __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
  __half2 scale_mask_3 =
      __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
  __half2 scale_mask_4 =
      __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));

  // (val + bias) * scale_mask + residual, fused per pair.
  out_half2[0] =
      __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]);
  out_half2[1] =
      __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]);
  out_half2[2] =
      __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]);
  out_half2[3] =
      __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]);
  outs_float4[i] = out_float4;
}
// Launch the float fused bias + dropout + residual kernel on `stream`.
// A block of 1024 threads covers 4096 floats (4 per thread); one extra block
// is always added to cover any remainder. `dim` is forwarded as the kernel's
// hidden_size; the seed is the wall clock in microseconds since epoch,
// narrowed to the kernel's int parameter.
template <>
void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
                                       uint8_t *mask, const float *bias,
                                       const float *residual, int total_count,
                                       int dim, float ratio,
                                       cudaStream_t stream) {
  const int num_blocks = (total_count >> 12) + 1;
  const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count();
  ls_dropout_res_bias_kernel<<<num_blocks, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual, now_us, dim);
}
// Launch the half fused bias + dropout + residual kernel on `stream`.
// A block of 1024 threads covers 8192 halves (8 per thread); one extra block
// is always added to cover any remainder. `dim` is forwarded as the kernel's
// hidden_size; the seed is the wall clock in microseconds since epoch,
// narrowed to the kernel's int parameter.
template <>
void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
                                        uint8_t *mask, const __half *bias,
                                        const __half *residual, int total_count,
                                        int dim, float ratio,
                                        cudaStream_t stream) {
  const int num_blocks = (total_count >> 13) + 1;
  const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count();
  ls_dropout_res_bias_kernel<<<num_blocks, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual, now_us, dim);
}
/**
 * @brief fused bias and dropout backward at the end of Attention and FFN
 *
 * For every element: in_grad = out_grad * mask / (1 - ratio).
 * For every column:  bias_grad = sum of in_grad over all rows.
 * Each block produces the bias gradient of 8 adjacent columns.
 *
 * @thread
 * gridDim.x = ceil(hidden_size / 8)
 * blockDim.x = 8
 * blockDim.y = 1024 / 8 = 128
 *
 * @param row_size batch_size * seq_len
 * @param ratio dropout ratio
 * @param in_grad [batch_size, seq_len, hidden_size], input grad
 * @param bias_grad [hidden_size], bias grad
 * @param out_grad [batch_size, seq_len, hidden_size], output grad
 * @param mask [batch_size, seq_len, hidden_size], dropout mask
 * @param hidden_size
 * @return void
 */
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, float *__restrict__ in_grad,
    float *__restrict__ bias_grad, const float *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  // every block generates 8 bias results: one row of partial sums per column.
  // NOTE(review): the 129th column is presumably padding against shared-memory
  // bank conflicts - confirm.
  __shared__ float tile[8][129];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  int stride = hidden_size * 128;
  float local_sum = 0;

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  // The grid is rounded up to a multiple of 8 columns, so the last block may
  // own columns past hidden_size; those threads must not touch global memory
  // but still have to take part in the barriers and shuffles below.
  if (col_idx < hidden_size) {
    for (int r = threadIdx.y; r < row_size; r += 128) {
      float val = out_grad[idx];
      val *= scale * static_cast<float>(mask[idx]);
      local_sum += val;
      in_grad[idx] = val;
      idx += stride;
    }
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  // Re-index the block linearly: 128 consecutive threads map to one column x,
  // and the first 32 of them fold that column's 128 partials into 32 values.
  float sum = 0;
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  __syncthreads();

  // After the shuffle loop the thread with y == 0 holds the column total.
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  // One store per in-range column; the other rows would only repeat it.
  if (threadIdx.y == 0 && col_idx < hidden_size) {
    bias_grad[col_idx] = tile[0][threadIdx.x];
  }
}
/**
 * @brief fused bias and dropout backward, __half overload
 *
 * Works on __half2 pairs: in_grad, out_grad and bias_grad are reinterpreted
 * as __half2 arrays, and hidden_size is the number of __half2 columns (the
 * launcher passes dim >> 1). The mask stays one byte per scalar element, so
 * pair idx reads mask[2 * idx] and mask[2 * idx + 1].
 * Partial sums are accumulated in __half2.
 *
 * @thread
 * gridDim.x = ceil(hidden_size / 8)
 * blockDim.x = 8
 * blockDim.y = 1024 / 8 = 128
 *
 * @param row_size batch_size * seq_len
 * @param ratio dropout ratio
 * @param in_grad input grad, written as __half2 pairs
 * @param bias_grad bias grad, written as __half2 pairs
 * @param out_grad output grad, read as __half2 pairs
 * @param mask dropout mask, one uint8 per scalar element
 * @param hidden_size number of __half2 columns
 * @return void
 */
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, __half *__restrict__ in_grad,
    __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
  // One row of partial sums per column handled by this block.
  __shared__ __half2 tile[8][129];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
  const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
  __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);

  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  int stride = hidden_size * 128;
  __half2 local_sum = __float2half2_rn(0.f);

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  // The grid is rounded up to a multiple of 8 columns, so the last block may
  // own columns past hidden_size; those threads must not touch global memory
  // but still have to take part in the barriers and shuffles below.
  if (col_idx < hidden_size) {
    for (int r = threadIdx.y; r < row_size; r += 128) {
      __half2 val = out_grad2[idx];
      __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
      val *= scale * m2;
      local_sum += val;
      in_grad2[idx] = val;
      idx += stride;
    }
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  // Re-index the block linearly: 128 consecutive threads map to one column x,
  // and the first 32 of them fold that column's 128 partials into 32 values.
  __half2 sum = __float2half2_rn(0.f);
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  __syncthreads();

  // After the shuffle loop the thread with y == 0 holds the column total.
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  // One store per in-range column; the other rows would only repeat it.
  if (threadIdx.y == 0 && col_idx < hidden_size) {
    bias_grad2[col_idx] = tile[0][threadIdx.x];
  }
}
/**
 * @brief Host-side launcher for ls_dropout_bias_bwd_kernel (generic type).
 *
 * Uses one block of 8 x 128 threads per group of 8 columns and forwards all
 * arguments unchanged to the kernel on the given stream.
 *
 * @param in_grad input grad (written by the kernel)
 * @param bias_grad bias grad (written by the kernel)
 * @param out_grad output grad
 * @param mask dropout mask
 * @param row_size number of rows to reduce over
 * @param dim number of columns, passed to the kernel as hidden_size
 * @param ratio dropout ratio
 * @param stream CUDA stream to launch on
 */
template <typename T>
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
                                const uint8_t *mask, int row_size, int dim,
                                float ratio, cudaStream_t stream) {
  // Number of 8-column groups, rounded up.
  const int num_blocks = (dim - 1) / 8 + 1;
  const dim3 threads(8, 128);
  ls_dropout_bias_bwd_kernel<<<dim3(num_blocks), threads, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}
/**
 * @brief __half specialization of launch_ls_dropout_bias_bwd.
 *
 * The column count is halved before sizing the grid and before being passed
 * to the kernel as hidden_size.
 * NOTE(review): an odd dim loses its last column in the shift - confirm that
 * callers always pass an even dim.
 */
template <>
void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
                                const __half *out_grad, const uint8_t *mask,
                                int row_size, int dim, float ratio,
                                cudaStream_t stream) {
  const int half2_cols = dim >> 1;
  // Number of 8-column groups, rounded up.
  const int num_blocks = (half2_cols - 1) / 8 + 1;
  const dim3 threads(8, 128);
  ls_dropout_bias_bwd_kernel<<<dim3(num_blocks), threads, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, half2_cols);
}
// Explicit instantiation of the generic launcher for float.
template void launch_ls_dropout_bias_bwd<float>(float *, float *,
                                                const float *,
                                                const uint8_t *, int, int,
                                                float, cudaStream_t);
/**
 * @brief fused bias, activation, and dropout at the end of first ffn
 *
 * Each thread handles 4 consecutive floats:
 *   out = activation(in + bias) * mask / (1 - ratio)
 * where mask is 1 when the random draw is greater than ratio, else 0.
 * The 4 mask bytes are stored with a single 32-bit write.
 *
 * @thread
 * 1-D launch. launch_ls_dropout_act_bias uses
 * gridDim.x = total_count / 1024 + 1
 * blockDim.x = 256
 *
 * NOTE(review): the bounds check is per group of 4 and all accesses are
 * float4 / uint32 wide, so total_count and hidden_size are presumably
 * multiples of 4 and the pointers suitably aligned - confirm with callers.
 *
 * @tparam act_type activation function, like kRelu, kGelu
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out [batch_size, seq_len, hidden_size], float
 * @param in [batch_size, seq_len, hidden_size], float
 * @param mask [batch_size, seq_len, hidden_size], uint8 type
 * @param bias [hidden_size], ffn bias
 * @param seed seed to curand
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const int seed, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  // The thread index is the curand subsequence, so each thread draws its own
  // stream from the shared seed.
  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  // Aligned so the array can be read back as one 32-bit word below.
  alignas(4) uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *data4 = reinterpret_cast<const float4 *>(in);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
  float4 rand = curand_uniform4(&state);

  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);

  // The bias repeats every hidden_size elements, i.e. hidden_size / 4 groups.
  int bias_i = i % (hidden_size >> 2);
  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  mask4[i] = m4[0];
  const float4 input4 = data4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  float4 output4;

  output4.x =
      activation_kernel<act_type, float>(input4.x + b4.x) * scale * m[0];
  output4.y =
      activation_kernel<act_type, float>(input4.y + b4.y) * scale * m[1];
  output4.z =
      activation_kernel<act_type, float>(input4.z + b4.z) * scale * m[2];
  output4.w =
      activation_kernel<act_type, float>(input4.w + b4.w) * scale * m[3];

  out4[i] = output4;
}
/**
 * @brief fused bias, activation, and dropout at the end of first ffn, __half
 * overload
 *
 * Each thread handles 8 consecutive halves, loaded and stored as one float4
 * and processed as 4 __half2 pairs:
 *   out = activation(in + bias) * mask / (1 - ratio)
 * where mask is 1 when the random draw is greater than ratio, else 0.
 * Two curand_uniform4 draws supply the 8 mask bytes, which are stored with a
 * single 64-bit write.
 *
 * @thread
 * 1-D launch. launch_ls_dropout_act_bias uses
 * gridDim.x = total_count / 2048 + 1
 * blockDim.x = 256
 *
 * NOTE(review): the bounds check is per group of 8 and all accesses are
 * float4 / uint64 wide, so total_count and hidden_size are presumably
 * multiples of 8 and the pointers suitably aligned - confirm with callers.
 *
 * @tparam act_type activation function, like kRelu, kGelu
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out [batch_size, seq_len, hidden_size], __half
 * @param in [batch_size, seq_len, hidden_size], __half
 * @param mask [batch_size, seq_len, hidden_size], uint8 type
 * @param bias [hidden_size], ffn bias
 * @param seed seed to curand
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const int seed, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);

  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  // The thread index is the curand subsequence, so each thread draws its own
  // stream from the shared seed.
  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  // Aligned so the array can be read back as one 64-bit word below.
  alignas(8) uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = (uint8_t)(rand.x > ratio);
  m[5] = (uint8_t)(rand.y > ratio);
  m[6] = (uint8_t)(rand.z > ratio);
  m[7] = (uint8_t)(rand.w > ratio);
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = *m8;

  // The bias repeats every hidden_size elements, i.e. hidden_size / 8 groups.
  int bias_i = i % (hidden_size >> 3);
  float4 val_float4 = vals_float4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  float4 out_float4;

  // View the three float4 registers as 4 __half2 pairs each.
  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);

  // Fold the dropout scale and the per-element mask into one multiplier.
  __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
  __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
  __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
  __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);

  out_half2[0] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[0], b_half2[0])),
      scale_mask_1);
  out_half2[1] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[1], b_half2[1])),
      scale_mask_2);
  out_half2[2] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[2], b_half2[2])),
      scale_mask_3);
  out_half2[3] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[3], b_half2[3])),
      scale_mask_4);

  outs_float4[i] = out_float4;
}
/**
 * @brief Launch the fused bias + GELU + dropout forward kernel for float.
 *
 * Blocks hold 256 threads; the block count is total_count / 1024 + 1.
 * The kernel seed is the current system-clock time in microseconds, narrowed
 * to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  const int num_blocks = (total_count >> 10) + 1;
  const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count();
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, now_us, dim);
}
/**
 * @brief Launch the fused bias + GELU + dropout forward kernel for __half.
 *
 * Blocks hold 256 threads; the block count is total_count / 2048 + 1.
 * The kernel seed is the current system-clock time in microseconds, narrowed
 * to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  const int num_blocks = (total_count >> 11) + 1;
  const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count();
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, now_us, dim);
}
/**
 * @brief Launch the fused bias + ReLU + dropout forward kernel for float.
 *
 * Blocks hold 256 threads; the block count is total_count / 1024 + 1.
 * The kernel seed is the current system-clock time in microseconds, narrowed
 * to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  const int num_blocks = (total_count >> 10) + 1;
  const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count();
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, now_us, dim);
}
/**
 * @brief Launch the fused bias + ReLU + dropout forward kernel for __half.
 *
 * Blocks hold 256 threads; the block count is total_count / 2048 + 1.
 * The kernel seed is the current system-clock time in microseconds, narrowed
 * to the kernel's int seed parameter.
 */
template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  const int num_blocks = (total_count >> 11) + 1;
  const auto now_us = std::chrono::duration_cast<std::chrono::microseconds>(
                          std::chrono::system_clock::now().time_since_epoch())
                          .count();
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<num_blocks, 256, 0, stream>>>(total_count, ratio, out, vals, mask,
                                       bias, now_us, dim);
}
/**
 * @brief fused bias, activation, and dropout backward
 *
 * For every element the upstream gradient is scaled by mask / (1 - ratio)
 * and passed, together with (input + bias), to activation_bwd_kernel; the
 * result is written to in_grad. bias_grad is the column-wise sum of in_grad.
 * All arithmetic and the reduction are done in float, whatever T is.
 * Each block produces the bias gradient of WARP_SIZE adjacent columns.
 *
 * @thread
 * gridDim.x = ceil(hidden_size / WARP_SIZE)
 * blockDim.x = WARP_SIZE
 * blockDim.y = WARP_SIZE
 *
 * @tparam act_type activation function, like kRelu, kGelu
 * @tparam T element type, float or __half
 * @param row_size batch_size * seq_len
 * @param ratio dropout ratio
 * @param in_grad [batch_size, seq_len, hidden_size], input grad
 * @param bias_grad [hidden_size], bias grad
 * @param input [batch_size, seq_len, hidden_size], forward input
 * @param bias [hidden_size], forward bias
 * @param out_grad [batch_size, seq_len, hidden_size], output grad
 * @param mask [batch_size, seq_len, hidden_size], dropout mask
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type, typename T>
__global__ void ls_dropout_act_bias_bwd_kernel(
    const int row_size, const float ratio, T *in_grad,
    T *__restrict__ bias_grad, const T *__restrict__ input,
    const T *__restrict__ bias, const T *out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  // NOTE(review): the extra column is presumably padding against
  // shared-memory bank conflicts on the transposed read below - confirm.
  __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);

  int stride = hidden_size * WARP_SIZE;
  float local_sum = 0;

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  // Threads past the last column skip global memory but still take part in
  // the barriers and shuffles below.
  if (col_idx < hidden_size) {
    for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
      float val = out_grad[idx];
      float in = input[idx];
      float bias_val = bias[idx % hidden_size];
      val = activation_bwd_kernel<act_type, float>(
          val * scale * static_cast<float>(mask[idx]), in + bias_val);
      local_sum += val;
      in_grad[idx] = val;
      idx += stride;
    }
  }

  // Transposed read: afterwards threads sharing threadIdx.y hold the partial
  // sums of one column, so a reduction over threadIdx.x totals that column.
  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();
  float sum = tile[threadIdx.y][threadIdx.x];
  __syncthreads();

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
  __syncthreads();

  // The grid is rounded up, so the last block may cover columns past
  // hidden_size; only in-range columns are stored.
  if (threadIdx.y == 0 && col_idx < hidden_size) {
    bias_grad[col_idx] = tile[0][threadIdx.x];
  }
}
// @brief fused bias, activation, and dropout backward (__half2 variant)
// Deprecated for precision reasons; kept for future optimization.
//
// template <ActivationType act_type> // template <ActivationType act_type>
// __global__ void ls_dropout_act_bias_bwd_kernel( // __global__ void ls_dropout_act_bias_bwd_kernel(
// const int row_size, const float ratio, __half * in_grad, // const int row_size, const float ratio, __half * in_grad,
// __half *__restrict__ bias_grad, const __half *__restrict__ input, const // __half *__restrict__ bias_grad, const __half *__restrict__ input, const
// __half *__restrict__ bias, const __half * out_grad, const uint8_t // __half *__restrict__ bias, const __half * out_grad, const uint8_t
// *__restrict__ mask, const int hidden_size) { // *__restrict__ mask, const int hidden_size) {
// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); // const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; // __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];
// cg::thread_block b = cg::this_thread_block(); // cg::thread_block b = cg::this_thread_block();
// cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); // cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); // __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); // __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
// const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad); // const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
// const __half2 *input2 = reinterpret_cast<const __half2 *>(input); // const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
// const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias); // const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);
// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); // int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// int stride = hidden_size * WARP_SIZE; // int stride = hidden_size * WARP_SIZE;
// __half2 local_sum = __float2half2_rn(0.f); // __half2 local_sum = __float2half2_rn(0.f);
// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); // int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
// if (col_idx < hidden_size) { // if (col_idx < hidden_size) {
// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { // for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
// __half2 val = out_grad2[idx]; // __half2 val = out_grad2[idx];
// __half2 in2 = input2[idx]; // __half2 in2 = input2[idx];
// __half2 b2 = bias2[idx % hidden_size ]; // __half2 b2 = bias2[idx % hidden_size ];
// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); // __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
// val = activation_bwd_kernel<ActivationType::kRelu, __half2>(val * scale // val = activation_bwd_kernel<ActivationType::kRelu, __half2>(val * scale
// * // *
// m2, // m2,
// in2+b2); // in2+b2);
// local_sum += val; // local_sum += val;
// in_grad2[idx] = val; // in_grad2[idx] = val;
// idx += stride; // idx += stride;
// } // }
// } // }
// tile[threadIdx.x][threadIdx.y] = local_sum; // tile[threadIdx.x][threadIdx.y] = local_sum;
// __syncthreads(); // __syncthreads();
// __half2 sum = tile[threadIdx.y][threadIdx.x]; // __half2 sum = tile[threadIdx.y][threadIdx.x];
// __syncthreads(); // __syncthreads();
// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); // for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; // if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
// __syncthreads(); // __syncthreads();
// if (threadIdx.y == 0) { // if (threadIdx.y == 0) {
// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); // int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
// bias_grad2[pos] = tile[0][threadIdx.x]; // bias_grad2[pos] = tile[0][threadIdx.x];
// } // }
// } // }
/**
 * @brief Host-side launcher for ls_dropout_act_bias_bwd_kernel.
 *
 * Uses one WARP_SIZE x WARP_SIZE block per group of WARP_SIZE columns and
 * forwards all arguments unchanged to the kernel on the given stream.
 *
 * @tparam act_type activation function whose backward pass is applied
 * @tparam T element type of the tensors
 * @param in_grad input grad (written by the kernel)
 * @param bias_grad bias grad (written by the kernel)
 * @param input forward input
 * @param bias forward bias
 * @param out_grad output grad
 * @param mask dropout mask
 * @param row_size number of rows to reduce over
 * @param dim number of columns, passed to the kernel as hidden_size
 * @param ratio dropout ratio
 * @param stream CUDA stream to launch on
 */
template <ActivationType act_type, typename T>
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
                                    const T *bias, const T *out_grad,
                                    const uint8_t *mask, int row_size, int dim,
                                    float ratio, cudaStream_t stream) {
  // Number of WARP_SIZE-column groups, rounded up.
  const int num_blocks = (dim - 1) / WARP_SIZE + 1;
  const dim3 threads(WARP_SIZE, WARP_SIZE);
  ls_dropout_act_bias_bwd_kernel<act_type>
      <<<dim3(num_blocks), threads, 0, stream>>>(row_size, ratio, in_grad,
                                                 bias_grad, input, bias,
                                                 out_grad, mask, dim);
}
// template <> // template <>
// void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>( // void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
// __half *in_grad, __half *bias_grad,const __half *input, const __half // __half *in_grad, __half *bias_grad,const __half *input, const __half
// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int // *bias, const __half *out_grad, const uint8_t *mask, int row_size, int
// dim, float ratio, cudaStream_t stream) { // dim, float ratio, cudaStream_t stream) {
// dim >>= 1; // dim >>= 1;
// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); // dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
// dim3 block_dim(WARP_SIZE, WARP_SIZE); // dim3 block_dim(WARP_SIZE, WARP_SIZE);
// ls_dropout_act_bias_bwd_kernel<ActivationType::kRelu> // ls_dropout_act_bias_bwd_kernel<ActivationType::kRelu>
// <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad, // <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
// bias_grad, // bias_grad,
// input, bias,out_grad, mask, dim); // input, bias,out_grad, mask, dim);
// } // }
// Explicit instantiations of the backward launcher: {kRelu, kGelu} x
// {float, __half}.
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
    float *, float *, const float *, const float *, const float *,
    const uint8_t *, int, int, float, cudaStream_t);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
    __half *, __half *, const __half *, const __half *, const __half *,
    const uint8_t *, int, int, float, cudaStream_t);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
    float *, float *, const float *, const float *, const float *,
    const uint8_t *, int, int, float, cudaStream_t);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
    __half *, __half *, const __half *, const __half *, const __half *,
    const uint8_t *, int, int, float, cudaStream_t);
#include <cooperative_groups.h> #include <cooperative_groups.h>
#include "kernels.h" #include "kernels.h"
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
/** /**
@brief: fuse_transpose_bias @brief: fuse_transpose_bias
Calculate the sum of elements in each column of the matrix. Calculate the sum of elements in each column of the matrix.
@thread @thread
gridDim.x = ceil(cols / WARP_SIZE) gridDim.x = ceil(cols / WARP_SIZE)
blockDim.x = WARP_SIZE blockDim.x = WARP_SIZE
blockDim.y = WARP_SIZE blockDim.y = WARP_SIZE
@param @param
inp: [rows, cols] inp: [rows, cols]
out: [cols] out: [cols]
rows: the number of rows in the matrix rows: the number of rows in the matrix
cols: the number of cols in the matrix cols: the number of cols in the matrix
*/ */
template <typename T> template <typename T>
__global__ void column_sum_reduce(const T *__restrict__ inp, __global__ void column_sum_reduce(const T *__restrict__ inp,
T *__restrict__ out, int rows, int cols) { T *__restrict__ out, int rows, int cols) {
__shared__ float tile[WARP_SIZE][WARP_SIZE]; __shared__ float tile[WARP_SIZE][WARP_SIZE];
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
int y_stride = cols * WARP_SIZE; int y_stride = cols * WARP_SIZE;
float localSum = 0; float localSum = 0;
// Loop across matrix row // Loop across matrix row
// TODO: optimize to log complexity // TODO: optimize to log complexity
if (idx < cols) { if (idx < cols) {
int offset = flat_2dim(threadIdx.y, idx, cols); int offset = flat_2dim(threadIdx.y, idx, cols);
for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
localSum += (float)inp[offset]; localSum += (float)inp[offset];
offset += y_stride; offset += y_stride;
} }
} }
// The sum of a row in tile is equal to the sum of a col in original matrix // The sum of a row in tile is equal to the sum of a col in original matrix
tile[threadIdx.x][threadIdx.y] = localSum; tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads(); __syncthreads();
// Sum the shared buffer. // Sum the shared buffer.
// The change of threadIdx.x is continuous // The change of threadIdx.x is continuous
float sum = tile[threadIdx.y][threadIdx.x]; float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads(); __syncthreads();
// Calculate the sum of a row in tile // Calculate the sum of a row in tile
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
if (pos < cols) out[pos] = sum; if (pos < cols) out[pos] = sum;
} }
} }
// [r, c] -> [c] // [r, c] -> [c]
template <> template <>
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out, void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
int rows, int cols, int rows, int cols,
cudaStream_t stream) { cudaStream_t stream) {
dim3 grid_dim((cols - 1) / WARP_SIZE + 1); dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE); dim3 block_dim(WARP_SIZE, WARP_SIZE);
column_sum_reduce<float> column_sum_reduce<float>
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols); <<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
} }
template <> template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
int rows, int cols, int rows, int cols,
cudaStream_t stream) { cudaStream_t stream) {
dim3 grid_dim((cols - 1) / WARP_SIZE + 1); dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE); dim3 block_dim(WARP_SIZE, WARP_SIZE);
column_sum_reduce<__half> column_sum_reduce<__half>
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols); <<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
} }
/**
@brief: fused_add2
Add two matrix inp1 and inp2 to out.

@thread
gridDim.x = batch_size * seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)

@param
out: [batch_size, seq_len, hidden_dim]
inp1: [batch_size, seq_len, hidden_dim]
inp2: [batch_size, seq_len, hidden_dim]
hidden_dim: row length as seen by the kernel. The specializations below
  address the buffers as float4, and the host launchers pass the element
  count already divided by 4 (float) or 8 (__half).
*/
template <typename T>
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
                                  int hidden_dim);

// float specialization: every loop iteration adds one float4 (4 floats).
// NOTE(review): the float4 reinterpretation presumably requires 16-byte
// aligned buffers - confirm against the callers.
template <>
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
                                         const float *inp2, int hidden_dim) {
  // One block per row; row_start is the row's first float4 index.
  const int row_start = flat_2dim(blockIdx.x, 0, hidden_dim);

  const float4 *lhs = reinterpret_cast<const float4 *>(inp1);
  const float4 *rhs = reinterpret_cast<const float4 *>(inp2);
  float4 *dst = reinterpret_cast<float4 *>(out);

  // Threads of the block stride across the row in steps of blockDim.x.
  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    const float4 a = lhs[row_start + i];
    const float4 b = rhs[row_start + i];
    dst[row_start + i] =
        make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
  }
}
// __half specialization of fused_add2_kernel: every loop iteration loads one
// float4 (16 bytes = 8 halves) from each input and adds them as four __half2
// pairs. `hidden_dim` is the row length in float4 units.
// NOTE(review): the float4 reinterpretation presumably requires 16-byte
// aligned buffers - confirm against the callers.
template <>
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
                                          const __half *inp2, int hidden_dim) {
  // One block per row; row_start is the row's first float4 index.
  const int row_start = flat_2dim(blockIdx.x, 0, hidden_dim);

  const float4 *lhs = reinterpret_cast<const float4 *>(inp1);
  const float4 *rhs = reinterpret_cast<const float4 *>(inp2);
  float4 *dst = reinterpret_cast<float4 *>(out);

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    float4 a = lhs[row_start + i];
    float4 b = rhs[row_start + i];
    float4 sum;

    // View each 16-byte register group as four __half2 values.
    const __half2 *a_h2 = reinterpret_cast<const __half2 *>(&a);
    const __half2 *b_h2 = reinterpret_cast<const __half2 *>(&b);
    __half2 *sum_h2 = reinterpret_cast<__half2 *>(&sum);

#pragma unroll
    for (int k = 0; k < 4; ++k) {
      sum_h2[k] = __hadd2(a_h2[k], b_h2[k]);
    }

    dst[row_start + i] = sum;
  }
}
//[b, s, h] -> [b, s, h] //[b, s, h] -> [b, s, h]
template <> template <>
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2, void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
int batch_size, int seq_len, int hidden_dim, int batch_size, int seq_len, int hidden_dim,
cudaStream_t &stream) { cudaStream_t &stream) {
hidden_dim >>= 2; hidden_dim >>= 2;
dim3 grid_dim(batch_size * seq_len); dim3 grid_dim(batch_size * seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS)); dim3 block_dim(min(hidden_dim, MAX_THREADS));
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2, fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
hidden_dim); hidden_dim);
} }
template <> template <>
void launch_fused_add2<__half>(__half *out, const __half *inp1, void launch_fused_add2<__half>(__half *out, const __half *inp1,
const __half *inp2, int batch_size, int seq_len, const __half *inp2, int batch_size, int seq_len,
int hidden_dim, cudaStream_t &stream) { int hidden_dim, cudaStream_t &stream) {
hidden_dim >>= 3; hidden_dim >>= 3;
dim3 grid_dim(batch_size * seq_len); dim3 grid_dim(batch_size * seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS)); dim3 block_dim(min(hidden_dim, MAX_THREADS));
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2, fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
hidden_dim); hidden_dim);
} }
// Copies one float4 per thread from inp1 or inp2 into output, where output is
// indexed as [sz0, sz1_1 + sz1_2, sz2]: positions whose middle index is below
// sz1_1 are read from inp1, the remaining ones from inp2.
// All sizes are in float4 units (the launchers shrink sz2 before the call).
// NOTE(review): source addressing relies on flat_3dim, presumably row-major
// over [sz0, sz1, sz2] - confirm in kernels.h.
template <typename T>
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
                                    int sz0, int sz2, int sz1_1, int sz1_2) {
  const int sz1_total = sz1_1 + sz1_2;
  const int linear = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
  if (linear >= sz0 * sz2 * sz1_total) {
    return;  // surplus threads of the last block
  }

  // Split the flat output index into (idx0, idx1, idx2).
  const int idx2 = linear % sz2;
  const int outer = linear / sz2;
  const int idx0 = outer / sz1_total;
  int idx1 = outer % sz1_total;

  // Choose the source tensor and rebase idx1 into that tensor's own dim 1.
  const bool from_first = idx1 < sz1_1;
  const float4 *src =
      reinterpret_cast<const float4 *>(from_first ? inp1 : inp2);
  const int src_sz1 = from_first ? sz1_1 : sz1_2;
  if (!from_first) {
    idx1 -= sz1_1;
  }

  float4 *dst = reinterpret_cast<float4 *>(output);
  dst[linear] = src[flat_3dim(idx0, idx1, idx2, src_sz1, sz2)];
}
template <> template <>
void launch_concat3_dim1<float>(const float *inp1, const float *inp2, void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
float *output, int sz0, int sz2, int sz1_1, float *output, int sz0, int sz2, int sz1_1,
int sz1_2, cudaStream_t stream) { int sz1_2, cudaStream_t stream) {
sz2 >>= 2; sz2 >>= 2;
int nele = sz0 * sz2 * (sz1_1 + sz1_2); int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>( kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
} }
template <> template <>
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
__half *output, int sz0, int sz2, int sz1_1, __half *output, int sz0, int sz2, int sz1_1,
int sz1_2, cudaStream_t stream) { int sz1_2, cudaStream_t stream) {
sz2 >>= 3; sz2 >>= 3;
int nele = sz0 * sz2 * (sz1_1 + sz1_2); int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>( kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
} }
#pragma once #pragma once
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include "kernels.h" #include "kernels.h"
template <typename T> template <typename T>
class Dropout { class Dropout {
public: public:
struct Config { struct Config {
float ratio; float ratio;
bool training; bool training;
Config(float r) : ratio(r), training(true) {} Config(float r) : ratio(r), training(true) {}
float RATIO() const { return training ? ratio : 0.0; } float RATIO() const { return training ? ratio : 0.0; }
}; };
Dropout(const Config &config, size_t max_ele_num) Dropout(const Config &config, size_t max_ele_num)
: _config(config), _mask(nullptr) { : _config(config), _mask(nullptr) {
_mask = cuda_malloc<uint8_t>(max_ele_num); _mask = cuda_malloc<uint8_t>(max_ele_num);
} }
virtual ~Dropout() { cuda_free(_mask); } virtual ~Dropout() { cuda_free(_mask); }
// after attention softmax // after attention softmax
void dropout(T *output, const T *input, int count, cudaStream_t stream, void dropout(T *output, const T *input, int count, cudaStream_t stream,
bool bwd = false) { bool bwd = false) {
launch_ls_dropout<T>(output, input, _mask, count, _config.RATIO(), stream, launch_ls_dropout<T>(output, input, _mask, count, _config.RATIO(), stream,
bwd); bwd);
} }
void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { void d_dropout(T *d_inp_out, int count, cudaStream_t stream) {
launch_ls_dropout<T>(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), launch_ls_dropout<T>(d_inp_out, d_inp_out, _mask, count, _config.RATIO(),
stream, true); stream, true);
} }
// transformer layer's postprocessing dropout, after attn or ffn module, // transformer layer's postprocessing dropout, after attn or ffn module,
// before residual add. // before residual add.
void bias_dropout_residual(T *output, const T *input, const T *residual, void bias_dropout_residual(T *output, const T *input, const T *residual,
const T *bias, int rows, int cols, const T *bias, int rows, int cols,
cudaStream_t stream) { cudaStream_t stream) {
launch_ls_dropout_res_bias<T>(output, input, _mask, bias, residual, launch_ls_dropout_res_bias<T>(output, input, _mask, bias, residual,
rows * cols, cols, _config.RATIO(), stream); rows * cols, cols, _config.RATIO(), stream);
} }
void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output,
int rows, int cols, cudaStream_t stream) { int rows, int cols, cudaStream_t stream) {
launch_ls_dropout_bias_bwd<T>(d_input, d_bias, d_output, _mask, rows, cols, launch_ls_dropout_bias_bwd<T>(d_input, d_bias, d_output, _mask, rows, cols,
_config.RATIO(), stream); _config.RATIO(), stream);
} }
// dropout inside ffn. // dropout inside ffn.
void bias_act_dropout(T *output, const T *input, const T *bias, int rows, void bias_act_dropout(T *output, const T *input, const T *bias, int rows,
int cols, std::string activation_fn, int cols, std::string activation_fn,
cudaStream_t stream) { cudaStream_t stream) {
if (activation_fn == "relu") { if (activation_fn == "relu") {
launch_ls_dropout_act_bias<ActivationType::kRelu, T>( launch_ls_dropout_act_bias<ActivationType::kRelu, T>(
output, input, _mask, bias, rows * cols, cols, _config.RATIO(), output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
stream); stream);
} else if (activation_fn == "gelu") { } else if (activation_fn == "gelu") {
launch_ls_dropout_act_bias<ActivationType::kGelu, T>( launch_ls_dropout_act_bias<ActivationType::kGelu, T>(
output, input, _mask, bias, rows * cols, cols, _config.RATIO(), output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
stream); stream);
} else { } else {
throw std::runtime_error("not supported activation: " + activation_fn); throw std::runtime_error("not supported activation: " + activation_fn);
} }
} }
void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input,
const T *bias, int rows, int cols, const T *bias, int rows, int cols,
std::string activation_fn, cudaStream_t stream) { std::string activation_fn, cudaStream_t stream) {
if (activation_fn == "relu") { if (activation_fn == "relu") {
launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, T>( launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, T>(
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
_config.RATIO(), stream); _config.RATIO(), stream);
} else if (activation_fn == "gelu") { } else if (activation_fn == "gelu") {
launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, T>( launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, T>(
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
_config.RATIO(), stream); _config.RATIO(), stream);
} else { } else {
throw std::runtime_error("not supported activation: " + activation_fn); throw std::runtime_error("not supported activation: " + activation_fn);
} }
} }
bool HasDropout() const { return _config.RATIO() > 0.0; } bool HasDropout() const { return _config.RATIO() > 0.0; }
void SetTrainingMode(bool training) { _config.training = training; } void SetTrainingMode(bool training) { _config.training = training; }
private: private:
uint8_t *_mask; uint8_t *_mask;
Config _config; Config _config;
}; };
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <curand_kernel.h> #include <curand_kernel.h>
#include <stdexcept>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <stdexcept>
#define MAX_THREADS 1024 #define MAX_THREADS 1024
#define WARP_SIZE 32 #define WARP_SIZE 32
...@@ -132,8 +133,9 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, ...@@ -132,8 +133,9 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
} }
/* Convert 4-dim tensor index into vector index */ /* Convert 4-dim tensor index into vector index */
__forceinline__ __host__ __device__ int __forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3,
flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) { int id4, int dim2, int dim3,
int dim4) {
// return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
int res = id4; int res = id4;
...@@ -201,9 +203,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, ...@@ -201,9 +203,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
} }
/* Convert vector index to 6-dim tensor index */ /* Convert vector index to 6-dim tensor index */
__forceinline__ __host__ __device__ void __forceinline__ __host__ __device__ void decompose_6dim(
decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0,
int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) { int *id1, int *id2, int *id3, int *id4, int *id5) {
*id5 = src % dim5; *id5 = src % dim5;
src /= dim5; src /= dim5;
...@@ -221,9 +223,11 @@ decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, ...@@ -221,9 +223,11 @@ decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5,
} }
/* Convert vector index to 5-dim tensor index */ /* Convert vector index to 5-dim tensor index */
__forceinline__ __host__ __device__ void __forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1,
decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0, int dim2, int dim3,
int *id1, int *id2, int *id3, int *id4) { int dim4, int *id0,
int *id1, int *id2,
int *id3, int *id4) {
*id4 = src % dim4; *id4 = src % dim4;
src /= dim4; src /= dim4;
...@@ -253,8 +257,9 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, ...@@ -253,8 +257,9 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
} }
/* Convert vector index to 3-dim tensor index */ /* Convert vector index to 3-dim tensor index */
__forceinline__ __host__ __device__ void __forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1,
decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { int dim2, int *id0,
int *id1, int *id2) {
*id2 = src % dim2; *id2 = src % dim2;
src /= dim2; src /= dim2;
......
#pragma once #pragma once
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <stdio.h> #include <stdio.h>
#include <fstream> #include <fstream>
#include "kernels.h" #include "kernels.h"
using namespace std; using namespace std;
template <typename T> class Normalize_Layer { template <typename T>
public: class Normalize_Layer {
struct Config { public:
uint32_t hidden_dim; struct Config {
bool use_mean; uint32_t hidden_dim;
Config(uint32_t hidden_dim, bool use_mean = false) bool use_mean;
: hidden_dim(hidden_dim), use_mean(use_mean) {} Config(uint32_t hidden_dim, bool use_mean = false)
}; : hidden_dim(hidden_dim), use_mean(use_mean) {}
};
Normalize_Layer(Config config, size_t max_rows)
: config_(config), vars_(nullptr), means_(nullptr) { Normalize_Layer(Config config, size_t max_rows)
vars_ = cuda_malloc<T>(max_rows); : config_(config), vars_(nullptr), means_(nullptr) {
if (config_.use_mean) { vars_ = cuda_malloc<T>(max_rows);
means_ = cuda_malloc<T>(max_rows); if (config_.use_mean) {
} means_ = cuda_malloc<T>(max_rows);
} }
}
~Normalize_Layer() {
cuda_free(vars_); ~Normalize_Layer() {
cuda_free(means_); cuda_free(vars_);
} cuda_free(means_);
}
void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta,
int batch_size, cudaStream_t stream) { void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta,
launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, int batch_size, cudaStream_t stream) {
config_.hidden_dim, stream); launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size,
} config_.hidden_dim, stream);
}
/*
residual_grad, inp_or_out, betta should be treated carefully. /*
inp_or_out = input if use_mean else output residual_grad, inp_or_out, betta should be treated carefully.
residual_grad, betta can be nullptr. inp_or_out = input if use_mean else output
residual_grad will be added to dinp if it is not nullptr residual_grad, betta can be nullptr.
which is useful in transformer layer when pre-ln residual_grad will be added to dinp if it is not nullptr
betta are only used to compute xhat, which is useful in transformer layer when pre-ln
(use_mean == false) ^ (betta == nullptr) should be true betta are only used to compute xhat,
*/ (use_mean == false) ^ (betta == nullptr) should be true
void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, */
const T *residual_grad, const T *inp_or_out, const T *gamma, void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
const T *betta, int batch_size, cudaStream_t stream[2]) { const T *residual_grad, const T *inp_or_out, const T *gamma,
launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, const T *betta, int batch_size, cudaStream_t stream[2]) {
inp_or_out, gamma, betta, vars_, means_, batch_size, launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad,
config_.hidden_dim, stream); inp_or_out, gamma, betta, vars_, means_, batch_size,
} config_.hidden_dim, stream);
}
inline bool use_mean() const { return config_.use_mean; }
inline bool use_mean() const { return config_.use_mean; }
private:
Config config_; private:
T *vars_; Config config_;
T *means_; T *vars_;
}; T *means_;
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment