Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -20,18 +20,18 @@ class Unpad(torch.autograd.Function): ...@@ -20,18 +20,18 @@ class Unpad(torch.autograd.Function):
# [b, s, ...] # [b, s, ...]
assert tensor.ndim >= 3 assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0] ctx.bsz = tensor.shape[0]
out = rearrange(tensor, 'b s ... -> (b s) ...') out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape ctx.shape = out.shape
# [ntokens, ...] # [ntokens, ...]
return out[indices] return out[indices]
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indices, = ctx.saved_tensors (indices,) = ctx.saved_tensors
# [ntokens, ...] # [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output grad[indices] = grad_output
grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz) grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...] # [b, s, ...]
return grad, None return grad, None
...@@ -54,7 +54,7 @@ class Repad(torch.autograd.Function): ...@@ -54,7 +54,7 @@ class Repad(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
indices, = ctx.saved_tensors (indices,) = ctx.saved_tensors
# [b*s, ...] # [b*s, ...]
grad = grad_output[indices] grad = grad_output[indices]
# [ntokens, ...] # [ntokens, ...]
......
...@@ -36,34 +36,64 @@ colossal_multihead_attention = None ...@@ -36,34 +36,64 @@ colossal_multihead_attention = None
@dataclass @dataclass
class Config: class Config:
max_batch_tokens: int # max batch token numbers max_batch_tokens: int # max batch token numbers
max_seq_len: int # max sequence length max_seq_len: int # max sequence length
hidden_size: int # size of transformer hidden layers hidden_size: int # size of transformer hidden layers
nhead: int # number of heads in attention nhead: int # number of heads in attention
attn_prob_dropout_ratio: float # attention score dropout ratio attn_prob_dropout_ratio: float # attention score dropout ratio
hidden_dropout_ratio: float # dropout ration before residual hidden_dropout_ratio: float # dropout ration before residual
norm_first: bool # norm_first norm_first: bool # norm_first
fp16: bool # fp16 precision fp16: bool # fp16 precision
class MultiHeadAttention1DFunc(Function): class MultiHeadAttention1DFunc(Function):
@staticmethod @staticmethod
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, def forward(
norm_bias, config): ctx,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
config,
):
cuda_module = colossal_multihead_attention cuda_module = colossal_multihead_attention
forward_func = (cuda_module.multihead_attention_fw_fp16 forward_func = (
if config.fp16 else cuda_module.multihead_attention_fw_fp32) cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32
)
if config.fp16: if config.fp16:
input = input.to(torch.half) input = input.to(torch.half)
input_mask = input_mask.to(torch.half) input_mask = input_mask.to(torch.half)
(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, (output,) = forward_func(
out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first) config.layer_id,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
config.training,
config.norm_first,
)
if config.is_grad_enabled and config.training: if config.is_grad_enabled and config.training:
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, ctx.save_for_backward(
out_proj_bias, norm_weight, norm_bias) output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
)
ctx.config = config ctx.config = config
return output return output
...@@ -72,11 +102,21 @@ class MultiHeadAttention1DFunc(Function): ...@@ -72,11 +102,21 @@ class MultiHeadAttention1DFunc(Function):
assert ctx.config.training assert ctx.config.training
cuda_module = colossal_multihead_attention cuda_module = colossal_multihead_attention
backward_func = (cuda_module.multihead_attention_bw_fp16 backward_func = (
if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32) cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32
)
output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \ (
out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
) = ctx.saved_tensors
grad_input = None grad_input = None
grad_in_proj_weight = None grad_in_proj_weight = None
...@@ -91,13 +131,39 @@ class MultiHeadAttention1DFunc(Function): ...@@ -91,13 +131,39 @@ class MultiHeadAttention1DFunc(Function):
output = output.to(torch.half) output = output.to(torch.half)
input = input.to(torch.half) input = input.to(torch.half)
input_mask = input_mask.to(torch.half) input_mask = input_mask.to(torch.half)
grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \ (
grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func( grad_input,
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, grad_in_proj_weight,
in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) grad_in_proj_bias,
grad_out_proj_weight,
grad_out_proj_bias,
grad_norm_weight,
grad_norm_bias,
) = backward_func(
ctx.config.layer_id,
grad_output,
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
)
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, return (
grad_norm_weight, grad_norm_bias, None) grad_input,
None,
grad_in_proj_weight,
grad_in_proj_bias,
grad_out_proj_weight,
grad_out_proj_bias,
grad_norm_weight,
grad_norm_bias,
None,
)
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
...@@ -122,8 +188,9 @@ class MultiHeadAttention(nn.Module): ...@@ -122,8 +188,9 @@ class MultiHeadAttention(nn.Module):
def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
super(MultiHeadAttention, self).__init__() super(MultiHeadAttention, self).__init__()
self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, self.config = Config(
fp16) batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16
)
check_config(self.config) check_config(self.config)
self.pg = pg self.pg = pg
self.pg_size = 1 self.pg_size = 1
...@@ -136,13 +203,17 @@ class MultiHeadAttention(nn.Module): ...@@ -136,13 +203,17 @@ class MultiHeadAttention(nn.Module):
global colossal_multihead_attention global colossal_multihead_attention
if colossal_multihead_attention is None: if colossal_multihead_attention is None:
from colossalai.kernel.op_builder import MultiHeadAttnBuilder from colossalai.kernel.op_builder import MultiHeadAttnBuilder
multihead_attention = MultiHeadAttnBuilder().load() multihead_attention = MultiHeadAttnBuilder().load()
colossal_multihead_attention = multihead_attention colossal_multihead_attention = multihead_attention
# create the layer in cuda kernels. # create the layer in cuda kernels.
cuda_module = colossal_multihead_attention cuda_module = colossal_multihead_attention
create_layer_func = (cuda_module.create_multihead_attention_fp16 create_layer_func = (
if self.config.fp16 else cuda_module.create_multihead_attention_fp32) cuda_module.create_multihead_attention_fp16
if self.config.fp16
else cuda_module.create_multihead_attention_fp32
)
create_layer_func( create_layer_func(
self.config.layer_id, self.config.layer_id,
...@@ -204,13 +275,15 @@ class MultiHeadAttention(nn.Module): ...@@ -204,13 +275,15 @@ class MultiHeadAttention(nn.Module):
with torch.no_grad(): with torch.no_grad():
self.in_proj_weight.copy_( self.in_proj_weight.copy_(
attn_qkvw_global.view(3, hs, hs)[:, attn_qkvw_global.view(3, hs, hs)[
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), :
self.pg_size), :]) ]
)
self.in_proj_bias.copy_( self.in_proj_bias.copy_(
attn_qkvb_global.view(3, hs)[:, attn_qkvb_global.view(3, hs)[
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)
self.pg_size)]) ]
)
attn_ow_global = torch.empty(hs, hs) attn_ow_global = torch.empty(hs, hs)
nn.init.xavier_uniform_(attn_ow_global, 1.0) nn.init.xavier_uniform_(attn_ow_global, 1.0)
...@@ -218,9 +291,9 @@ class MultiHeadAttention(nn.Module): ...@@ -218,9 +291,9 @@ class MultiHeadAttention(nn.Module):
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
attn_ow_global = attn_ow_global.cpu() attn_ow_global = attn_ow_global.cpu()
with torch.no_grad(): with torch.no_grad():
self.out_proj_weight.copy_(attn_ow_global[:, self.out_proj_weight.copy_(
int(hs * rank_in_pg / attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)]
self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)]) )
else: else:
attn_qkvw = self.in_proj_weight.view(-1, hs) attn_qkvw = self.in_proj_weight.view(-1, hs)
...@@ -238,7 +311,7 @@ class MultiHeadAttention(nn.Module): ...@@ -238,7 +311,7 @@ class MultiHeadAttention(nn.Module):
self.config.training = self.training self.config.training = self.training
self.config.is_grad_enabled = torch.is_grad_enabled() self.config.is_grad_enabled = torch.is_grad_enabled()
hidden_states = hidden_states.contiguous() hidden_states = hidden_states.contiguous()
encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()) encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()
bs, sl, dim = hidden_states.size() bs, sl, dim = hidden_states.size()
if bs * sl > self.config.max_batch_tokens: if bs * sl > self.config.max_batch_tokens:
...@@ -250,8 +323,16 @@ class MultiHeadAttention(nn.Module): ...@@ -250,8 +323,16 @@ class MultiHeadAttention(nn.Module):
else: else:
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight, output = MultiHeadAttention1DFunc.apply(
self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, hidden_states,
self.norm_weight, self.norm_bias, self.config) encoder_padding_mask,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.norm_weight,
self.norm_bias,
self.config,
)
return output.to(self.precision) return output.to(self.precision)
...@@ -108,15 +108,16 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -108,15 +108,16 @@ class FusedScaleMaskSoftmax(nn.Module):
super(FusedScaleMaskSoftmax, self).__init__() super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16 self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16 self.input_in_bf16 = input_in_bf16
assert not (self.input_in_fp16 assert not (
and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time." self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32 self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale self.scale = scale
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled" assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask): def forward(self, input, mask):
# [b, np, sq, sk] # [b, np, sq, sk]
...@@ -130,13 +131,14 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -130,13 +131,14 @@ class FusedScaleMaskSoftmax(nn.Module):
def is_kernel_available(self, mask, b, np, sq, sk): def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np attn_batches = b * np
if (self.scaled_masked_softmax_fusion # user want to fuse if (
and self.input_in_float16 # input must be fp16 self.scaled_masked_softmax_fusion # user want to fuse
and mask is not None # mask tensor must not be None and self.input_in_float16 # input must be fp16
and 16 < sk <= 2048 # sk must be 16 ~ 2048 and mask is not None # mask tensor must not be None
and sq % 4 == 0 # sq must be divisor of 4 and 16 < sk <= 2048 # sk must be 16 ~ 2048
and attn_batches % 4 == 0 # np * b must be divisor of 4 and sq % 4 == 0 # sq must be divisor of 4
): and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 2048: if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np) batch_per_block = self.get_batch_per_block(sq, sk, b, np)
......
from .option import set_jit_fusion_options from .bias_dropout_add import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .bias_gelu import bias_gelu_impl from .bias_gelu import bias_gelu_impl
from .option import set_jit_fusion_options
__all__ = [ __all__ = [
"bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", "bias_dropout_add_fused_train",
"set_jit_fusion_options" "bias_dropout_add_fused_inference",
"bias_gelu_impl",
"set_jit_fusion_options",
] ]
import torch import torch
from torch import Tensor
def bias_dropout_add(x, bias, residual, prob, training): def bias_dropout_add(x, bias, residual, prob, training):
...@@ -10,16 +9,14 @@ def bias_dropout_add(x, bias, residual, prob, training): ...@@ -10,16 +9,14 @@ def bias_dropout_add(x, bias, residual, prob, training):
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, def bias_dropout_add_fused_train(
bias: torch.Tensor, x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
residual: torch.Tensor, ) -> torch.Tensor:
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True) return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor, def bias_dropout_add_fused_inference(
bias: torch.Tensor, x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
residual: torch.Tensor, ) -> torch.Tensor:
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False) return bias_dropout_add(x, bias, residual, prob, False)
...@@ -29,7 +29,6 @@ def bias_gelu_back(g, bias, y): ...@@ -29,7 +29,6 @@ def bias_gelu_back(g, bias, y):
class GeLUFunction(torch.autograd.Function): class GeLUFunction(torch.autograd.Function):
@staticmethod @staticmethod
# bias is an optional argument # bias is an optional argument
def forward(ctx, input, bias): def forward(ctx, input, bias):
......
...@@ -10,15 +10,14 @@ JIT_OPTIONS_SET = False ...@@ -10,15 +10,14 @@ JIT_OPTIONS_SET = False
def set_jit_fusion_options(): def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options. """Set PyTorch JIT layer fusion options."""
"""
# LSG: the latest pytorch and CUDA versions may not support # LSG: the latest pytorch and CUDA versions may not support
# the following jit settings # the following jit settings
global JIT_OPTIONS_SET global JIT_OPTIONS_SET
if JIT_OPTIONS_SET == False: if JIT_OPTIONS_SET == False:
# flags required to enable jit fusion kernels # flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split(".")[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser # nvfuser
torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_executor(True)
...@@ -38,12 +37,14 @@ def set_jit_fusion_options(): ...@@ -38,12 +37,14 @@ def set_jit_fusion_options():
JIT_OPTIONS_SET = True JIT_OPTIONS_SET = True
def warmup_jit_fusion(batch_size: int, def warmup_jit_fusion(
hidden_size: int, batch_size: int,
seq_length: int = 512, hidden_size: int,
vocab_size: int = 32768, seq_length: int = 512,
dtype: torch.dtype = torch.float32): vocab_size: int = 32768,
""" Compile JIT functions before the main training steps """ dtype: torch.dtype = torch.float32,
):
"""Compile JIT functions before the main training steps"""
embed = Embedding(vocab_size, hidden_size).to(get_current_device()) embed = Embedding(vocab_size, hidden_size).to(get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
......
try: try:
import triton import triton
HAS_TRITON = True HAS_TRITON = True
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
...@@ -11,8 +12,14 @@ try: ...@@ -11,8 +12,14 @@ try:
from .token_attention_kernel import token_attention_fwd from .token_attention_kernel import token_attention_fwd
__all__ = [ __all__ = [
"llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward", "llama_context_attn_fwd",
"copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd" "bloom_context_attn_fwd",
"softmax",
"layer_norm",
"rmsnorm_forward",
"copy_kv_cache_to_dest",
"rotary_embedding_fwd",
"token_attention_fwd",
] ]
except ImportError: except ImportError:
......
import torch
import math import math
import torch
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
...@@ -10,28 +13,42 @@ except ImportError: ...@@ -10,28 +13,42 @@ except ImportError:
if HAS_TRITON: if HAS_TRITON:
''' """
this function is modified from this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
''' """
@triton.jit @triton.jit
def _context_flash_attention_kernel( def _context_flash_attention_kernel(
Q, K, V, sm_scale, Q,
B_Start_Loc, B_Seqlen, K,
TMP, V,
sm_scale,
B_Start_Loc,
B_Seqlen,
TMP,
alibi_ptr, alibi_ptr,
Out, Out,
stride_qbs, stride_qh, stride_qd, stride_qbs,
stride_kbs, stride_kh, stride_kd, stride_qh,
stride_vbs, stride_vh, stride_vd, stride_qd,
stride_obs, stride_oh, stride_od, stride_kbs,
stride_tmp_b, stride_tmp_h, stride_tmp_s, stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_tmp_b,
stride_tmp_h,
stride_tmp_s,
# suggtest set-up 64, 128, 256, 512 # suggtest set-up 64, 128, 256, 512
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
): ):
batch_id = tl.program_id(0) batch_id = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
start_m = tl.program_id(2) start_m = tl.program_id(2)
...@@ -40,13 +57,18 @@ if HAS_TRITON: ...@@ -40,13 +57,18 @@ if HAS_TRITON:
offs_n = tl.arange(0, BLOCK_N) offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL) offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# get batch info # get batch info
cur_batch_seq_len = tl.load(B_Seqlen + batch_id) cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
cur_batch_start_index = tl.load(B_Start_Loc + batch_id) cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
block_start_loc = BLOCK_M * start_m block_start_loc = BLOCK_M * start_m
load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd load_p_ptrs = (
Q
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
...@@ -56,7 +78,7 @@ if HAS_TRITON: ...@@ -56,7 +78,7 @@ if HAS_TRITON:
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if alibi_ptr is not None: if alibi_ptr is not None:
alibi_m = tl.load(alibi_ptr + cur_head) alibi_m = tl.load(alibi_ptr + cur_head)
...@@ -64,8 +86,11 @@ if HAS_TRITON: ...@@ -64,8 +86,11 @@ if HAS_TRITON:
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N) start_n = tl.multiple_of(start_n, BLOCK_N)
k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, k = tl.load(
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k) qk += tl.dot(q, k)
...@@ -95,21 +120,25 @@ if HAS_TRITON: ...@@ -95,21 +120,25 @@ if HAS_TRITON:
acc_scale = tl.load(t_ptrs) acc_scale = tl.load(t_ptrs)
acc = acc * acc_scale[:, None] acc = acc * acc_scale[:, None]
# update acc # update acc
v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, v = tl.load(
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype) p = p.to(v.dtype)
acc += tl.dot(p, v) acc += tl.dot(p, v)
# update m_i and l_i # update m_i and l_i
l_i = l_i_new l_i = l_i_new
m_i = m_i_new m_i = m_i_new
off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return return
@torch.no_grad() @torch.no_grad()
def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
BLOCK = 128 BLOCK = 128
...@@ -129,17 +158,31 @@ if HAS_TRITON: ...@@ -129,17 +158,31 @@ if HAS_TRITON:
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid]( _context_flash_attention_kernel[grid](
q, k, v, sm_scale, q,
b_start_loc, b_seq_len, k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp, tmp,
alibi, alibi,
o, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(0),
k.stride(0), k.stride(1), k.stride(2), q.stride(1),
v.stride(0), v.stride(1), v.stride(2), q.stride(2),
o.stride(0), o.stride(1), o.stride(2), k.stride(0),
tmp.stride(0), tmp.stride(1), tmp.stride(2), k.stride(1),
# manually setting this blcok num, we can use tuning config to futher speed-up k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
# manually setting this blcok num, we can use tuning config to futher speed-up
BLOCK_M=BLOCK, BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
...@@ -147,7 +190,7 @@ if HAS_TRITON: ...@@ -147,7 +190,7 @@ if HAS_TRITON:
num_stages=1, num_stages=1,
) )
return return
@torch.no_grad() @torch.no_grad()
def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 128 BLOCK = 128
...@@ -166,19 +209,34 @@ if HAS_TRITON: ...@@ -166,19 +209,34 @@ if HAS_TRITON:
num_warps = 4 if Lk <= 64 else 8 num_warps = 4 if Lk <= 64 else 8
# num_warps = 4 # num_warps = 4
_context_flash_attention_kernel[grid]( _context_flash_attention_kernel[grid](
q, k, v, sm_scale, b_start_loc, b_seq_len, q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp, tmp,
None, None,
o, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(0),
k.stride(0), k.stride(1), k.stride(2), q.stride(1),
v.stride(0), v.stride(1), v.stride(2), q.stride(2),
o.stride(0), o.stride(1), o.stride(2), k.stride(0),
tmp.stride(0), tmp.stride(1), tmp.stride(2), k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
BLOCK_M=BLOCK, BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
num_warps=num_warps, num_warps=num_warps,
num_stages=1, num_stages=1,
) )
return return
\ No newline at end of file
...@@ -3,25 +3,28 @@ import torch ...@@ -3,25 +3,28 @@ import torch
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
print("please install triton from https://github.com/openai/triton") print("please install triton from https://github.com/openai/triton")
if HAS_TRITON: if HAS_TRITON:
@triton.jit @triton.jit
def _fwd_copy_kv_cache_dest( def _fwd_copy_kv_cache_dest(
kv_cache_ptr, dest_index_ptr, kv_cache_ptr,
dest_index_ptr,
out, out,
stride_k_bs, stride_k_bs,
stride_k_h, stride_k_h,
stride_k_d, stride_k_d,
stride_o_bs, stride_o_bs,
stride_o_h, stride_o_h,
stride_o_d, stride_o_d,
head_num, head_num,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr BLOCK_HEAD: tl.constexpr,
): ):
cur_index = tl.program_id(0) cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD) offs_h = tl.arange(0, BLOCK_HEAD)
...@@ -31,15 +34,14 @@ if HAS_TRITON: ...@@ -31,15 +34,14 @@ if HAS_TRITON:
cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets
o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
o_ptrs = out + dest_index * stride_o_bs + o_offsets o_ptrs = out + dest_index * stride_o_bs + o_offsets
k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
return return
@torch.no_grad() @torch.no_grad()
def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
seq_len = dest_index_ptr.shape[0] seq_len = dest_index_ptr.shape[0]
...@@ -47,16 +49,18 @@ if HAS_TRITON: ...@@ -47,16 +49,18 @@ if HAS_TRITON:
head_dim = k_ptr.shape[2] head_dim = k_ptr.shape[2]
assert head_num == out.shape[1], "head_num should be the same for k_ptr and out" assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2 num_warps = 2
_fwd_copy_kv_cache_dest[(seq_len,)]( _fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr, dest_index_ptr, out, k_ptr,
k_ptr.stride(0), dest_index_ptr,
k_ptr.stride(1), out,
k_ptr.stride(0),
k_ptr.stride(1),
k_ptr.stride(2), k_ptr.stride(2),
out.stride(0), out.stride(0),
out.stride(1), out.stride(1),
out.stride(2), out.stride(2),
head_num, head_num,
BLOCK_DMODEL=head_dim, BLOCK_DMODEL=head_dim,
...@@ -65,5 +69,3 @@ if HAS_TRITON: ...@@ -65,5 +69,3 @@ if HAS_TRITON:
num_stages=2, num_stages=2,
) )
return return
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
...@@ -14,13 +15,13 @@ if HAS_TRITON: ...@@ -14,13 +15,13 @@ if HAS_TRITON:
@triton.jit @triton.jit
def _layer_norm_fwd_fused( def _layer_norm_fwd_fused(
X, # pointer to the input X, # pointer to the input
Y, # pointer to the output Y, # pointer to the output
W, # pointer to the weights W, # pointer to the weights
B, # pointer to the biases B, # pointer to the biases
stride, # how much to increase the pointer when moving by 1 row stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X N, # number of columns in X
eps, # epsilon to avoid division by zero eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
# Map the program id to the row of X and Y it should compute. # Map the program id to the row of X and Y it should compute.
...@@ -32,15 +33,15 @@ if HAS_TRITON: ...@@ -32,15 +33,15 @@ if HAS_TRITON:
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE): for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE) cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
_mean += a _mean += a
mean = tl.sum(_mean, axis=0) / N mean = tl.sum(_mean, axis=0) / N
# Compute variance # Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE): for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE) cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
x = tl.where(cols < N, x - mean, 0.) x = tl.where(cols < N, x - mean, 0.0)
_var += x * x _var += x * x
var = tl.sum(_var, axis=0) / N var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps) rstd = 1 / tl.sqrt(var + eps)
...@@ -50,7 +51,7 @@ if HAS_TRITON: ...@@ -50,7 +51,7 @@ if HAS_TRITON:
mask = cols < N mask = cols < N
w = tl.load(W + cols, mask=mask) w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask) b = tl.load(B + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - mean) * rstd x_hat = (x - mean) * rstd
y = x_hat * w + b y = x_hat * w + b
# Write output # Write output
...@@ -71,13 +72,7 @@ if HAS_TRITON: ...@@ -71,13 +72,7 @@ if HAS_TRITON:
# heuristics for number of warps # heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8) num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel # enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg, _layer_norm_fwd_fused[(M,)](
y, x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
weight, )
bias,
x_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
return y return y
import torch
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
...@@ -9,9 +9,10 @@ except ImportError: ...@@ -9,9 +9,10 @@ except ImportError:
if HAS_TRITON: if HAS_TRITON:
''' """
this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
''' """
@triton.jit @triton.jit
def qkv_gemm_4d_kernel( def qkv_gemm_4d_kernel(
a_ptr, a_ptr,
...@@ -34,12 +35,12 @@ if HAS_TRITON: ...@@ -34,12 +35,12 @@ if HAS_TRITON:
stride_cn, stride_cn,
scale, scale,
# Meta-parameters # Meta-parameters
BLOCK_SIZE_M : tl.constexpr = 64, BLOCK_SIZE_M: tl.constexpr = 64,
BLOCK_SIZE_N : tl.constexpr = 32, BLOCK_SIZE_N: tl.constexpr = 32,
BLOCK_SIZE_K : tl.constexpr = 32, BLOCK_SIZE_K: tl.constexpr = 32,
GROUP_SIZE_M : tl.constexpr = 8, GROUP_SIZE_M: tl.constexpr = 8,
): ):
r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, r"""A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) where score_matrix is softmax(Q*V^T/sqrt(hidden_size))
Args: Args:
a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K)
...@@ -53,21 +54,21 @@ if HAS_TRITON: ...@@ -53,21 +54,21 @@ if HAS_TRITON:
stride_bh(tl.constexpr): stride for h-dimention for tensor array B stride_bh(tl.constexpr): stride for h-dimention for tensor array B
stride_bk(tl.constexpr): stride for k-dimention for tensor array B stride_bk(tl.constexpr): stride for k-dimention for tensor array B
stride_bn(tl.constexpr): stride for n-dimention for tensor array B stride_bn(tl.constexpr): stride for n-dimention for tensor array B
stride_cb(tl.constexpr): stride for bs-dimention for tensor array output stride_cb(tl.constexpr): stride for bs-dimention for tensor array output
stride_ch(tl.constexpr): stride for h-dimention for tensor array output stride_ch(tl.constexpr): stride for h-dimention for tensor array output
stride_cm(tl.constexpr): stride for m-dimention for tensor array output stride_cm(tl.constexpr): stride for m-dimention for tensor array output
stride_cn(tl.constexpr): stride for n-dimention for tensor array output stride_cn(tl.constexpr): stride for n-dimention for tensor array output
BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a
BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b
BLOCK_SIZE_K : tiling size for K-dimension of a and b BLOCK_SIZE_K : tiling size for K-dimension of a and b
GROUP_SIZE_M : group size for reducing cache miss, more details: GROUP_SIZE_M : group size for reducing cache miss, more details:
""" """
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
batch = tl.program_id(axis = 0) batch = tl.program_id(axis=0)
head = tl.program_id(axis = 1) head = tl.program_id(axis=1)
pid = tl.program_id(axis = 2) pid = tl.program_id(axis=2)
# the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
num_pid_in_group = GROUP_SIZE_M * num_pid_n num_pid_in_group = GROUP_SIZE_M * num_pid_n
...@@ -77,33 +78,38 @@ if HAS_TRITON: ...@@ -77,33 +78,38 @@ if HAS_TRITON:
pid_m = first_pid_m + (pid % group_size_m) pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + a_ptrs = (
(offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) a_ptr + batch * stride_ab + head * stride_ah + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + )
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) b_ptrs = (
b_ptr + batch * stride_bb + head * stride_bh + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K): for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.) a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.) b = tl.load(b_ptrs, mask=b_mask, other=0.0)
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
accumulator = accumulator.to(c_ptr.dtype.element_ty) accumulator = accumulator.to(c_ptr.dtype.element_ty)
if scale > 0: if scale > 0:
accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)
offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] + c_ptrs = (
stride_cn * offs_accumu_n[None, :]) c_ptr
+ batch * stride_cb
+ head * stride_ch
+ stride_cm * offs_accumu_m[:, None]
+ stride_cn * offs_accumu_n[None, :]
)
accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
tl.store(c_ptrs, accumulator, mask=accumulator_mask) tl.store(c_ptrs, accumulator, mask=accumulator_mask)
...@@ -3,17 +3,19 @@ import torch ...@@ -3,17 +3,19 @@ import torch
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
print("please install triton from https://github.com/openai/triton") print("please install triton from https://github.com/openai/triton")
if HAS_TRITON: if HAS_TRITON:
''' """
this kernel function is modified from this kernel function is modified from
https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
''' """
@triton.jit @triton.jit
def _rms_norm_fwd_fused( def _rms_norm_fwd_fused(
X, # pointer to the input X, # pointer to the input
...@@ -32,7 +34,7 @@ if HAS_TRITON: ...@@ -32,7 +34,7 @@ if HAS_TRITON:
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE): for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE) cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
_var += x * x _var += x * x
var = tl.sum(_var, axis=0) / N var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps) rstd = 1 / tl.sqrt(var + eps)
...@@ -41,13 +43,12 @@ if HAS_TRITON: ...@@ -41,13 +43,12 @@ if HAS_TRITON:
cols = off + tl.arange(0, BLOCK_SIZE) cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32) w = tl.load(W + cols, mask=mask).to(tl.float32)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = x * rstd x_hat = x * rstd
y = x_hat * w y = x_hat * w
# Write output # Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask) tl.store(Y + cols, y.to(tl.float16), mask=mask)
def rmsnorm_forward(x, weight, eps): def rmsnorm_forward(x, weight, eps):
# allocate output # allocate output
y = torch.empty_like(x) y = torch.empty_like(x)
...@@ -66,7 +67,5 @@ if HAS_TRITON: ...@@ -66,7 +67,5 @@ if HAS_TRITON:
BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2 BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
num_warps = 8 num_warps = 8
# enqueue kernel # enqueue kernel
_rms_norm_fwd_fused[(M,)](x_arg, y, weight, _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y return y
...@@ -29,19 +29,29 @@ def _rotary_kernel( ...@@ -29,19 +29,29 @@ def _rotary_kernel(
dim_range0 = tl.arange(0, HEAD_DIM // 2) dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ off_q0 = (
None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride current_seq_range[:, None, None] * q_bs_stride
off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ + current_head_range[None, :, None] * q_h_stride
None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride + dim_range0[None, None, :] * q_d_stride
)
off_q1 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range1[None, None, :] * q_d_stride
)
off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
q0 = tl.load(q + off_q0, q0 = tl.load(
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), q + off_q0,
other=0.0) mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
q1 = tl.load(q + off_q1, other=0.0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), )
other=0.0) q1 = tl.load(
q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
...@@ -49,12 +59,16 @@ def _rotary_kernel( ...@@ -49,12 +59,16 @@ def _rotary_kernel(
out0 = q0 * cos - q1 * sin out0 = q0 * cos - q1 * sin
out1 = q0 * sin + q1 * cos out1 = q0 * sin + q1 * cos
tl.store(q + off_q0, tl.store(
out0, q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) out0,
tl.store(q + off_q1, mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
out1, )
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) tl.store(
q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
return return
......
import torch import torch
from torch import nn
try: try:
import triton import triton
import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
...@@ -13,9 +12,10 @@ if HAS_TRITON: ...@@ -13,9 +12,10 @@ if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .qkv_matmul_kernel import qkv_gemm_4d_kernel
from .softmax import softmax_kernel from .softmax import softmax_kernel
def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def self_attention_forward_without_fusion(
input_mask: torch.Tensor, scale: float): q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float
r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels ):
r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args: Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
...@@ -65,7 +65,7 @@ if HAS_TRITON: ...@@ -65,7 +65,7 @@ if HAS_TRITON:
score_output.stride(2), score_output.stride(2),
score_output.stride(3), score_output.stride(3),
scale=scale, scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting # currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64, BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32, BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32, BLOCK_SIZE_K=32,
...@@ -79,7 +79,6 @@ if HAS_TRITON: ...@@ -79,7 +79,6 @@ if HAS_TRITON:
n_rows, n_cols = score_output.shape n_rows, n_cols = score_output.shape
if n_rows <= 350000: if n_rows <= 350000:
block_size = max(triton.next_power_of_2(n_cols), 2) block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4 num_warps = 4
if block_size >= 4096: if block_size >= 4096:
...@@ -142,15 +141,9 @@ if HAS_TRITON: ...@@ -142,15 +141,9 @@ if HAS_TRITON:
) )
return output.view(batches, -1, d_model) return output.view(batches, -1, d_model)
def self_attention_compute_using_triton(qkv, def self_attention_compute_using_triton(
input_mask, qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False
layer_past, ):
alibi,
scale,
head_size,
triangular=False,
use_flash=False):
assert qkv.is_contiguous() assert qkv.is_contiguous()
assert alibi is None, "current triton self-attention does not support alibi" assert alibi is None, "current triton self-attention does not support alibi"
batches = qkv.shape[0] batches = qkv.shape[0]
...@@ -158,8 +151,8 @@ if HAS_TRITON: ...@@ -158,8 +151,8 @@ if HAS_TRITON:
num_of_heads = d_model // head_size num_of_heads = d_model // head_size
q = qkv[:, :, :d_model] q = qkv[:, :, :d_model]
k = qkv[:, :, d_model:d_model * 2] k = qkv[:, :, d_model : d_model * 2]
v = qkv[:, :, d_model * 2:] v = qkv[:, :, d_model * 2 :]
q = q.view(batches, -1, num_of_heads, head_size) q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size)
......
import torch import torch
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
print("please install triton from https://github.com/openai/triton") print("please install triton from https://github.com/openai/triton")
if HAS_TRITON: if HAS_TRITON:
''' """
softmax kernel is modified based on softmax kernel is modified based on
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
''' """
@triton.jit @triton.jit
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
r""" the kernel function for implementing softmax operator r"""the kernel function for implementing softmax operator
Args: Args:
output_ptr: the output after finishing softmax operation, (N, hidden_dim) output_ptr: the output after finishing softmax operation, (N, hidden_dim)
input_ptr: the tensor of input, shape should be (N, hidden_dim) input_ptr: the tensor of input, shape should be (N, hidden_dim)
n_cols(tl.constexpr): the number of cols of input n_cols(tl.constexpr): the number of cols of input
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
""" """
row_idx = tl.program_id(0) row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * row_stride row_start_ptr = input_ptr + row_idx * row_stride
col_offsets = tl.arange(0, BLOCK_SIZE) col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets input_ptrs = row_start_ptr + col_offsets
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0) row_minus_max = row - tl.max(row, axis=0)
if mask_ptr is not None: if mask_ptr is not None:
# load mask into SRAM # load mask into SRAM
mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
# update # update
row_minus_max = row_minus_max + mask row_minus_max = row_minus_max + mask
numerator = tl.exp(row_minus_max) numerator = tl.exp(row_minus_max)
...@@ -43,17 +46,16 @@ if HAS_TRITON: ...@@ -43,17 +46,16 @@ if HAS_TRITON:
output_ptrs = output_row_start_ptr + col_offsets output_ptrs = output_row_start_ptr + col_offsets
# Write back output to DRAM # Write back output to DRAM
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
if mask is not None: if mask is not None:
assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask"
assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" assert dim == -1 or dim == len(input.shape) - 1, "currently softmax layer only support last dimention"
hidden_dim = input.shape[-1] hidden_dim = input.shape[-1]
output = torch.empty_like(input) output = torch.empty_like(input)
input = input.view(-1, hidden_dim) input = input.view(-1, hidden_dim)
if mask is not None: if mask is not None:
mask = mask.view(-1, hidden_dim) mask = mask.view(-1, hidden_dim)
assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"
...@@ -67,30 +69,31 @@ if HAS_TRITON: ...@@ -67,30 +69,31 @@ if HAS_TRITON:
else: else:
num_warps = 4 num_warps = 4
if num_rows <= 350000: if num_rows <= 350000:
grid = (num_rows,) grid = (num_rows,)
softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) softmax_kernel[grid](
output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps
)
else: else:
grid = lambda meta: () grid = lambda meta: ()
grid = lambda meta: ( grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),)
triton.cdiv(num_rows, meta["BLOCK_M"]),
)
BLOCK_M = 32
if block_size >= 4096: if block_size >= 4096:
BLOCK_M = 4 pass
elif block_size >= 2048: elif block_size >= 2048:
BLOCK_M = 8 pass
softmax_kernel[grid](output_ptr = output, softmax_kernel[grid](
input_ptr = input, output_ptr=output,
row_stride = input.stride(0), input_ptr=input,
n_rows = num_rows, row_stride=input.stride(0),
n_cols = num_cols, n_rows=num_rows,
mask_ptr = mask, n_cols=num_cols,
# currently manually setting up size mask_ptr=mask,
BLOCK_M = 32, # currently manually setting up size
BLOCK_SIZE = block_size) BLOCK_M=32,
BLOCK_SIZE=block_size,
)
return output return output
\ No newline at end of file
# Adapted from ModelTC https://github.com/ModelTC/lightllm # Adapted from ModelTC https://github.com/ModelTC/lightllm
import math
import torch import torch
try: try:
import triton import triton
import triton.language as tl import triton.language as tl
HAS_TRITON = True HAS_TRITON = True
except ImportError: except ImportError:
HAS_TRITON = False HAS_TRITON = False
...@@ -15,10 +15,28 @@ except ImportError: ...@@ -15,10 +15,28 @@ except ImportError:
if HAS_TRITON: if HAS_TRITON:
@triton.jit @triton.jit
def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, def _token_attn_1_kernel(
attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride, Q,
q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride, K,
attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): sm_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0) current_batch = tl.program_id(0)
current_head = tl.program_id(1) current_head = tl.program_id(1)
start_n = tl.program_id(2) start_n = tl.program_id(2)
...@@ -40,9 +58,11 @@ if HAS_TRITON: ...@@ -40,9 +58,11 @@ if HAS_TRITON:
for start_mark in range(0, block_mask, 1): for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark) q = tl.load(Q + off_q + start_mark)
offs_n_new = current_batch_start_index + offs_n offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, k_loc = tl.load(
mask=offs_n_new < current_batch_end_index, kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
other=0) mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1) att_value = tl.sum(q[None, :] * k, 1)
...@@ -52,11 +72,29 @@ if HAS_TRITON: ...@@ -52,11 +72,29 @@ if HAS_TRITON:
return return
@triton.jit @triton.jit
def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, def _token_attn_1_alibi_kernel(
max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, Q,
q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride, K,
k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr, sm_scale,
BLOCK_N: tl.constexpr): alibi,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0) current_batch = tl.program_id(0)
current_head = tl.program_id(1) current_head = tl.program_id(1)
start_n = tl.program_id(2) start_n = tl.program_id(2)
...@@ -79,9 +117,11 @@ if HAS_TRITON: ...@@ -79,9 +117,11 @@ if HAS_TRITON:
alibi_m = tl.load(alibi + current_head) alibi_m = tl.load(alibi + current_head)
q = tl.load(Q + off_q + start_mark) q = tl.load(Q + off_q + start_mark)
offs_n_new = current_batch_start_index + offs_n offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, k_loc = tl.load(
mask=offs_n_new < current_batch_end_index, kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
other=0) mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1) att_value = tl.sum(q[None, :] * k, 1)
...@@ -92,14 +132,9 @@ if HAS_TRITON: ...@@ -92,14 +132,9 @@ if HAS_TRITON:
return return
@torch.no_grad() @torch.no_grad()
def token_attn_fwd_1(q, def token_attn_fwd_1(
k, q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None
attn_out, ):
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
alibi=None):
BLOCK = 32 BLOCK = 32
# shape constraints # shape constraints
q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
...@@ -168,9 +203,17 @@ if HAS_TRITON: ...@@ -168,9 +203,17 @@ if HAS_TRITON:
return return
@triton.jit @triton.jit
def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, def _token_attn_softmax_fwd(
logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride, softmax_logics,
BLOCK_SIZE: tl.constexpr): kv_cache_start_loc,
kv_cache_seqlen,
softmax_prob_out,
logics_head_dim_stride,
logics_batch_stride,
prob_head_dim_stride,
prob_batch_stride,
BLOCK_SIZE: tl.constexpr,
):
current_batch = tl.program_id(0) current_batch = tl.program_id(0)
current_head = tl.program_id(1) current_head = tl.program_id(1)
...@@ -178,20 +221,26 @@ if HAS_TRITON: ...@@ -178,20 +221,26 @@ if HAS_TRITON:
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
row = tl.load(softmax_logics + current_head * logics_head_dim_stride + row = tl.load(
(current_batch_in_all_start_index + col_offsets) * logics_batch_stride, softmax_logics
mask=col_offsets < current_batch_seq_len, + current_head * logics_head_dim_stride
other=-float('inf')).to(tl.float32) + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
mask=col_offsets < current_batch_seq_len,
other=-float("inf"),
).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0) row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max) numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0) denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator softmax_output = numerator / denominator
tl.store(softmax_prob_out + current_head * prob_head_dim_stride + tl.store(
(current_batch_in_all_start_index + col_offsets) * prob_batch_stride, softmax_prob_out
softmax_output, + current_head * prob_head_dim_stride
mask=col_offsets < current_batch_seq_len) + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
softmax_output,
mask=col_offsets < current_batch_seq_len,
)
return return
@torch.no_grad() @torch.no_grad()
...@@ -220,11 +269,27 @@ if HAS_TRITON: ...@@ -220,11 +269,27 @@ if HAS_TRITON:
return return
@triton.jit @triton.jit
def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, def _token_attn_2_kernel(
kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride, Prob,
v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride, V,
attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr, attn_out,
BLOCK_N: tl.constexpr): kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
prob_head_dim_stride,
prob_batch_stride,
v_batch_stride,
v_head_stride,
v_head_dim_stride,
attn_out_batch_stride,
attn_out_head_stride,
attn_out_head_dim_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0) current_batch = tl.program_id(0)
current_head = tl.program_id(1) current_head = tl.program_id(1)
...@@ -232,7 +297,6 @@ if HAS_TRITON: ...@@ -232,7 +297,6 @@ if HAS_TRITON:
offs_d = tl.arange(0, HEAD_DIM) offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_end_index = current_batch_seq_len
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
...@@ -242,19 +306,29 @@ if HAS_TRITON: ...@@ -242,19 +306,29 @@ if HAS_TRITON:
acc = tl.zeros([HEAD_DIM], dtype=tl.float32) acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for start_n in range(0, current_batch_seq_len, BLOCK_N): for start_n in range(0, current_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N) start_n = tl.multiple_of(start_n, BLOCK_N)
p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride, p_value = tl.load(
mask=(start_n + offs_n) < current_batch_seq_len, Prob + p_offs + start_n * kv_cache_loc_s_stride,
other=0.0) mask=(start_n + offs_n) < current_batch_seq_len,
v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, other=0.0,
mask=(start_n + offs_n) < current_batch_seq_len, )
other=0.0) v_loc = tl.load(
v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride, kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n[:, None]) < current_batch_seq_len, mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0) other=0.0,
)
v_value = tl.load(
V + v_offs + v_loc[:, None] * v_batch_stride,
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
other=0.0,
)
acc += tl.sum(p_value[:, None] * v_value, 0) acc += tl.sum(p_value[:, None] * v_value, 0)
acc = acc.to(tl.float16) acc = acc.to(tl.float16)
off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride off_o = (
current_batch * attn_out_batch_stride
+ current_head * attn_out_head_stride
+ offs_d * attn_out_head_dim_stride
)
out_ptrs = attn_out + off_o out_ptrs = attn_out + off_o
tl.store(out_ptrs, acc) tl.store(out_ptrs, acc)
return return
...@@ -296,15 +370,9 @@ if HAS_TRITON: ...@@ -296,15 +370,9 @@ if HAS_TRITON:
return return
@torch.no_grad() @torch.no_grad()
def token_attention_fwd(q, def token_attention_fwd(
k, q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None
v, ):
attn_out,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=None):
head_num = k.shape[1] head_num = k.shape[1]
batch_size = kv_cache_seq_len.shape[0] batch_size = kv_cache_seq_len.shape[0]
calcu_shape1 = (batch_size, head_num, k.shape[2]) calcu_shape1 = (batch_size, head_num, k.shape[2])
...@@ -312,21 +380,24 @@ if HAS_TRITON: ...@@ -312,21 +380,24 @@ if HAS_TRITON:
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
token_attn_fwd_1(q.view(calcu_shape1), token_attn_fwd_1(
k, q.view(calcu_shape1),
att_m_tensor, k,
kv_cache_loc, att_m_tensor,
kv_cache_start_loc, kv_cache_loc,
kv_cache_seq_len, kv_cache_start_loc,
max_len_in_batch, kv_cache_seq_len,
alibi=alibi) max_len_in_batch,
alibi=alibi,
)
prob = torch.empty_like(att_m_tensor) prob = torch.empty_like(att_m_tensor)
token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None att_m_tensor = None
token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, token_attn_fwd_2(
max_len_in_batch) prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
)
prob = None prob = None
......
from .lazy_init import LazyInitContext, LazyTensor from .lazy_init import LazyInitContext, LazyTensor
__all__ = [ __all__ = [
'LazyInitContext', "LazyInitContext",
'LazyTensor', "LazyTensor",
] ]
from contextlib import contextmanager
from types import MethodType from types import MethodType
from typing import Callable, Dict, Optional, Union from typing import Callable, Dict, Optional, Union
...@@ -35,43 +34,43 @@ _NO_META_FACTORY = [ ...@@ -35,43 +34,43 @@ _NO_META_FACTORY = [
"eye", "eye",
] ]
_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split'] _EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]
# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset) # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block. # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
# These ops cannot be unwrapped using .data # These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__', 'numel', 'size', 'dim'] _CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]
_LEGACY_TENSOR_CONSTRUCTOR = { _LEGACY_TENSOR_CONSTRUCTOR = {
'FloatTensor': torch.float, "FloatTensor": torch.float,
'DoubleTensor': torch.double, "DoubleTensor": torch.double,
'HalfTensor': torch.half, "HalfTensor": torch.half,
'BFloat16Tensor': torch.bfloat16, "BFloat16Tensor": torch.bfloat16,
'ByteTensor': torch.uint8, "ByteTensor": torch.uint8,
'CharTensor': torch.int8, "CharTensor": torch.int8,
'ShortTensor': torch.short, "ShortTensor": torch.short,
'IntTensor': torch.int, "IntTensor": torch.int,
'LongTensor': torch.long, "LongTensor": torch.long,
'BoolTensor': torch.bool, "BoolTensor": torch.bool,
} }
_EMPTY_DATA = torch.empty(0) _EMPTY_DATA = torch.empty(0)
class _MyTensor(Tensor): class _MyTensor(Tensor):
"""This class is only for correctness verification. """This class is only for correctness verification."""
"""
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
default_device: Optional[torch.device] = None default_device: Optional[torch.device] = None
def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor': def __new__(cls, func, *args, concrete_data=None, **kwargs) -> "_MyTensor":
cls._pre_op_fn() cls._pre_op_fn()
if concrete_data is not None: if concrete_data is not None:
# uniform api as LazyTensor # uniform api as LazyTensor
data = concrete_data data = concrete_data
else: else:
kwargs['device'] = cls.default_device kwargs["device"] = cls.default_device
data = func(*args, **kwargs) data = func(*args, **kwargs)
return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
...@@ -82,12 +81,11 @@ class _MyTensor(Tensor): ...@@ -82,12 +81,11 @@ class _MyTensor(Tensor):
def _data_tolist(tensor: torch.Tensor) -> list: def _data_tolist(tensor: torch.Tensor) -> list:
"""tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor. """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor."""
"""
return tensor.data.tolist() return tensor.data.tolist()
def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data. """Convert a lazy tensor's class to target's class, with target's data.
The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models. The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
...@@ -104,7 +102,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: ...@@ -104,7 +102,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
tensor.__class__ = cls_to_become tensor.__class__ = cls_to_become
if cls_to_become is Parameter: if cls_to_become is Parameter:
# to fit UninitializedParameter # to fit UninitializedParameter
delattr(tensor, '_is_param') delattr(tensor, "_is_param")
tensor.data = target tensor.data = target
tensor.requires_grad = target.requires_grad tensor.requires_grad = target.requires_grad
# subclass of torch.Tensor does not have tolist() method # subclass of torch.Tensor does not have tolist() method
...@@ -147,8 +145,8 @@ class LazyTensor(torch.Tensor): ...@@ -147,8 +145,8 @@ class LazyTensor(torch.Tensor):
""" """
_repr = True _repr = True
_meta_data: Optional[MetaTensor] = None # shape, dtype, device _meta_data: Optional[MetaTensor] = None # shape, dtype, device
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
default_device: Optional[torch.device] = None default_device: Optional[torch.device] = None
...@@ -159,8 +157,8 @@ class LazyTensor(torch.Tensor): ...@@ -159,8 +157,8 @@ class LazyTensor(torch.Tensor):
elem = concrete_data elem = concrete_data
else: else:
if meta_data is None: if meta_data is None:
device = kwargs.get('device', 'cpu') device = kwargs.get("device", "cpu")
elem = func(*args, **{**kwargs, 'device': 'meta'}) elem = func(*args, **{**kwargs, "device": "meta"})
meta_data = MetaTensor(elem, device=device) meta_data = MetaTensor(elem, device=device)
elem = meta_data._tensor elem = meta_data._tensor
# As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
...@@ -170,10 +168,10 @@ class LazyTensor(torch.Tensor): ...@@ -170,10 +168,10 @@ class LazyTensor(torch.Tensor):
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
if func.__name__ in _NORMAL_FACTORY: if func.__name__ in _NORMAL_FACTORY:
kwargs = {**kwargs, 'device': LazyTensor.default_device} kwargs = {**kwargs, "device": LazyTensor.default_device}
self._factory_method = (func, args, kwargs) # (func, args, kwargs) self._factory_method = (func, args, kwargs) # (func, args, kwargs)
self._op_buffer = [] # (func, args, kwargs, replace) self._op_buffer = [] # (func, args, kwargs, replace)
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
def materialize(self) -> torch.Tensor: def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
...@@ -200,12 +198,11 @@ class LazyTensor(torch.Tensor): ...@@ -200,12 +198,11 @@ class LazyTensor(torch.Tensor):
return _convert_cls(self, local_tensor) return _convert_cls(self, local_tensor)
def clean(self) -> None: def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
""" delattr(self, "_factory_method")
delattr(self, '_factory_method') delattr(self, "_op_buffer")
delattr(self, '_op_buffer') delattr(self, "_materialized_data")
delattr(self, '_materialized_data') delattr(self, "_meta_data")
delattr(self, '_meta_data')
@staticmethod @staticmethod
def _replace_with_materialized(x): def _replace_with_materialized(x):
...@@ -221,8 +218,9 @@ class LazyTensor(torch.Tensor): ...@@ -221,8 +218,9 @@ class LazyTensor(torch.Tensor):
# apply cached sequence # apply cached sequence
self._pre_op_fn() self._pre_op_fn()
init_val = func(*tree_map(self._replace_with_materialized, args), init_val = func(
**tree_map(self._replace_with_materialized, kwargs)) *tree_map(self._replace_with_materialized, args), **tree_map(self._replace_with_materialized, kwargs)
)
self._materialized_data = self._rerun_ops(init_val) self._materialized_data = self._rerun_ops(init_val)
return self._materialized_data return self._materialized_data
...@@ -243,13 +241,13 @@ class LazyTensor(torch.Tensor): ...@@ -243,13 +241,13 @@ class LazyTensor(torch.Tensor):
packed = None packed = None
for (func, args, kwargs) in self._op_buffer: for func, args, kwargs in self._op_buffer:
if func == torch.Tensor.requires_grad_: if func == torch.Tensor.requires_grad_:
packed = func, args, kwargs # requires grad should be set at last packed = func, args, kwargs # requires grad should be set at last
else: else:
self._pre_op_fn() self._pre_op_fn()
o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value
# super-dainiu: set requires_grad after all inplace-ops are done # super-dainiu: set requires_grad after all inplace-ops are done
if packed is not None: if packed is not None:
...@@ -268,8 +266,11 @@ class LazyTensor(torch.Tensor): ...@@ -268,8 +266,11 @@ class LazyTensor(torch.Tensor):
# These OPs cannot be lazy and related tensors should be early materialized # These OPs cannot be lazy and related tensors should be early materialized
tree_map(cls._replace_with_materialized, args) tree_map(cls._replace_with_materialized, args)
tree_map(cls._replace_with_materialized, kwargs) tree_map(cls._replace_with_materialized, kwargs)
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__')) is_inplace: bool = (
or func.__name__ in ('__setitem__', '__set__')) func.__name__.endswith("_")
and not (func.__name__.endswith("__"))
or func.__name__ in ("__setitem__", "__set__")
)
is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
...@@ -285,11 +286,11 @@ class LazyTensor(torch.Tensor): ...@@ -285,11 +286,11 @@ class LazyTensor(torch.Tensor):
target: LazyTensor = args[0].clone() target: LazyTensor = args[0].clone()
target._op_buffer.append((func, args, kwargs)) target._op_buffer.append((func, args, kwargs))
target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]), target._meta_data = getattr(target._meta_data, func.name)(
**tree_map(unwrap, kwargs)) *tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs)
)
return target return target
else: else:
meta_to_lazy = {} meta_to_lazy = {}
def unwrap(x): def unwrap(x):
...@@ -328,10 +329,9 @@ class LazyTensor(torch.Tensor): ...@@ -328,10 +329,9 @@ class LazyTensor(torch.Tensor):
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
pass # skip pass # skip
def clone(self) -> "LazyTensor": def clone(self) -> "LazyTensor":
def factory_fn(): def factory_fn():
# if self is materialized, return self # if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self new_tensor = self.materialize() if type(self) is LazyTensor else self
...@@ -346,8 +346,10 @@ class LazyTensor(torch.Tensor): ...@@ -346,8 +346,10 @@ class LazyTensor(torch.Tensor):
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
if not self.is_leaf: if not self.is_leaf:
raise RuntimeError("Only Tensors created explicitly by the user " raise RuntimeError(
"(graph leaves) support the deepcopy protocol at the moment") "Only Tensors created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment"
)
if id(self) in memo: if id(self) in memo:
return memo[id(self)] return memo[id(self)]
...@@ -375,7 +377,7 @@ class LazyTensor(torch.Tensor): ...@@ -375,7 +377,7 @@ class LazyTensor(torch.Tensor):
return self return self
@data.setter @data.setter
def data(self, other: 'LazyTensor'): def data(self, other: "LazyTensor"):
"""This is sightly different from oringinal `data` setter. """This is sightly different from oringinal `data` setter.
E.g.: E.g.:
...@@ -413,7 +415,7 @@ class LazyTensor(torch.Tensor): ...@@ -413,7 +415,7 @@ class LazyTensor(torch.Tensor):
def __rpow__(self, other): def __rpow__(self, other):
dtype = torch.result_type(self, other) dtype = torch.result_type(self, other)
return torch.tensor(other, dtype=dtype, device=self.device)**self return torch.tensor(other, dtype=dtype, device=self.device) ** self
class LazyInitContext: class LazyInitContext:
...@@ -444,11 +446,14 @@ class LazyInitContext: ...@@ -444,11 +446,14 @@ class LazyInitContext:
1. Quantization strategies can be applied before allocating real memory. 1. Quantization strategies can be applied before allocating real memory.
2. Lazy initialization seems slower than normal initialization. 2. Lazy initialization seems slower than normal initialization.
""" """
_replaced: bool = False _replaced: bool = False
def __init__(self, def __init__(
tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, self,
default_device: Optional[Union[torch.device, str, int]] = None): tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor,
default_device: Optional[Union[torch.device, str, int]] = None,
):
assert tensor_cls is LazyTensor or tensor_cls is _MyTensor assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self.overrides = {} self.overrides = {}
self.tensor_cls = tensor_cls self.tensor_cls = tensor_cls
...@@ -457,7 +462,7 @@ class LazyInitContext: ...@@ -457,7 +462,7 @@ class LazyInitContext:
def __enter__(self): def __enter__(self):
if LazyInitContext._replaced: if LazyInitContext._replaced:
raise RuntimeError(f'LazyInitContext is not reentrant') raise RuntimeError(f"LazyInitContext is not reentrant")
LazyInitContext._replaced = True LazyInitContext._replaced = True
self.old_default_device = self.tensor_cls.default_device self.old_default_device = self.tensor_cls.default_device
self.tensor_cls.default_device = self.default_device self.tensor_cls.default_device = self.default_device
...@@ -485,17 +490,17 @@ class LazyInitContext: ...@@ -485,17 +490,17 @@ class LazyInitContext:
return args[0] return args[0]
elif len(args) == 1: elif len(args) == 1:
# (object data, *, torch.device device) # (object data, *, torch.device device)
kwargs = {**kwargs, 'dtype': dtype} kwargs = {**kwargs, "dtype": dtype}
replaced, orig = self.overrides['tensor'] replaced, orig = self.overrides["tensor"]
return replaced(*args, **kwargs) return replaced(*args, **kwargs)
elif _is_int_tuple(args): elif _is_int_tuple(args):
# (tuple of ints size, *, torch.device device) # (tuple of ints size, *, torch.device device)
kwargs = {**kwargs, 'dtype': dtype} kwargs = {**kwargs, "dtype": dtype}
replaced, orig = self.overrides['empty'] replaced, orig = self.overrides["empty"]
return replaced(*args, **kwargs) return replaced(*args, **kwargs)
else: else:
raise TypeError( raise TypeError(
f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)' f"new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)"
) )
return wrapper, target return wrapper, target
...@@ -514,23 +519,29 @@ class LazyInitContext: ...@@ -514,23 +519,29 @@ class LazyInitContext:
if callable(getattr(torch, target, None)) if callable(getattr(torch, target, None))
} }
self.overrides.update({ self.overrides.update(
target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like')) {
for target in _NORMAL_FACTORY target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
if callable(getattr(torch, target + '_like', None)) for target in _NORMAL_FACTORY
}) if callable(getattr(torch, target + "_like", None))
}
self.overrides.update({ )
target: wrap_legacy_constructor(getattr(torch, target), dtype)
for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() self.overrides.update(
if callable(getattr(torch, target, None)) {
}) target: wrap_legacy_constructor(getattr(torch, target), dtype)
for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
self.overrides.update({ if callable(getattr(torch, target, None))
target: wrap_no_meta_factory(getattr(torch, target)) }
for target in _NO_META_FACTORY )
if callable(getattr(torch, target, None))
}) self.overrides.update(
{
target: wrap_no_meta_factory(getattr(torch, target))
for target in _NO_META_FACTORY
if callable(getattr(torch, target, None))
}
)
for name, (wrapper, orig) in self.overrides.items(): for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, wrapper) setattr(torch, name, wrapper)
...@@ -556,10 +567,9 @@ class LazyInitContext: ...@@ -556,10 +567,9 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose) return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod @staticmethod
def distribute(module: nn.Module, def distribute(
device_mesh: DeviceMesh, module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
sharding_spec_dict: Dict[str, ShardingSpec], ) -> nn.Module:
verbose: bool = False) -> nn.Module:
"""Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args: Args:
...@@ -574,9 +584,9 @@ class LazyInitContext: ...@@ -574,9 +584,9 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose) return _apply_to_lazy_module(module, apply_fn, verbose)
def _apply_to_lazy_module(module: nn.Module, def _apply_to_lazy_module(
apply_fn: Callable[[str, torch.Tensor], None], module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False
verbose: bool = False) -> nn.Module: ) -> nn.Module:
if verbose: if verbose:
# verbose info # verbose info
param_cnt = 0 param_cnt = 0
...@@ -590,7 +600,7 @@ def _apply_to_lazy_module(module: nn.Module, ...@@ -590,7 +600,7 @@ def _apply_to_lazy_module(module: nn.Module,
if verbose: if verbose:
param_cnt += 1 param_cnt += 1
total_numel += p.numel() total_numel += p.numel()
if getattr(p, '_materialized_data', False) is None: if getattr(p, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy # if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1 param_lazy_cnt += 1
else: else:
...@@ -612,10 +622,11 @@ def _apply_to_lazy_module(module: nn.Module, ...@@ -612,10 +622,11 @@ def _apply_to_lazy_module(module: nn.Module,
if verbose: if verbose:
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
_print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}') _print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}")
_print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}') _print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}")
_print_rank_0( _print_rank_0(
f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%') f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%"
)
return module return module
......
from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
__all__ = [ __all__ = [
'launch', "launch",
'launch_from_openmpi', "launch_from_openmpi",
'launch_from_slurm', "launch_from_slurm",
'launch_from_torch', "launch_from_torch",
'initialize', "initialize",
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment