Unverified Commit 079bf3cb authored by Hongxin Liu, committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -20,18 +20,18 @@ class Unpad(torch.autograd.Function):
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, 'b s ... -> (b s) ...')
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
......@@ -54,7 +54,7 @@ class Repad(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
......
......@@ -36,34 +36,64 @@ colossal_multihead_attention = None
@dataclass
class Config:
max_batch_tokens: int # max batch token numbers
max_seq_len: int # max sequence length
hidden_size: int # size of transformer hidden layers
nhead: int # number of heads in attention
attn_prob_dropout_ratio: float # attention score dropout ratio
hidden_dropout_ratio: float # dropout ratio before residual
norm_first: bool # norm_first
fp16: bool # fp16 precision
max_batch_tokens: int # max batch token numbers
max_seq_len: int # max sequence length
hidden_size: int # size of transformer hidden layers
nhead: int # number of heads in attention
attn_prob_dropout_ratio: float # attention score dropout ratio
hidden_dropout_ratio: float # dropout ratio before residual
norm_first: bool # norm_first
fp16: bool # fp16 precision
class MultiHeadAttention1DFunc(Function):
@staticmethod
def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight,
norm_bias, config):
def forward(
ctx,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
config,
):
cuda_module = colossal_multihead_attention
forward_func = (cuda_module.multihead_attention_fw_fp16
if config.fp16 else cuda_module.multihead_attention_fw_fp32)
forward_func = (
cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32
)
if config.fp16:
input = input.to(torch.half)
input_mask = input_mask.to(torch.half)
(output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first)
(output,) = forward_func(
config.layer_id,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
config.training,
config.norm_first,
)
if config.is_grad_enabled and config.training:
ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
out_proj_bias, norm_weight, norm_bias)
ctx.save_for_backward(
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
)
ctx.config = config
return output
......@@ -72,11 +102,21 @@ class MultiHeadAttention1DFunc(Function):
assert ctx.config.training
cuda_module = colossal_multihead_attention
backward_func = (cuda_module.multihead_attention_bw_fp16
if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32)
backward_func = (
cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32
)
output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \
out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors
(
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
) = ctx.saved_tensors
grad_input = None
grad_in_proj_weight = None
......@@ -91,13 +131,39 @@ class MultiHeadAttention1DFunc(Function):
output = output.to(torch.half)
input = input.to(torch.half)
input_mask = input_mask.to(torch.half)
grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \
grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func(
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)
(
grad_input,
grad_in_proj_weight,
grad_in_proj_bias,
grad_out_proj_weight,
grad_out_proj_bias,
grad_norm_weight,
grad_norm_bias,
) = backward_func(
ctx.config.layer_id,
grad_output,
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
)
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias,
grad_norm_weight, grad_norm_bias, None)
return (
grad_input,
None,
grad_in_proj_weight,
grad_in_proj_bias,
grad_out_proj_weight,
grad_out_proj_bias,
grad_norm_weight,
grad_norm_bias,
None,
)
class MultiHeadAttention(nn.Module):
......@@ -122,8 +188,9 @@ class MultiHeadAttention(nn.Module):
def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
super(MultiHeadAttention, self).__init__()
self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first,
fp16)
self.config = Config(
batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16
)
check_config(self.config)
self.pg = pg
self.pg_size = 1
......@@ -136,13 +203,17 @@ class MultiHeadAttention(nn.Module):
global colossal_multihead_attention
if colossal_multihead_attention is None:
from colossalai.kernel.op_builder import MultiHeadAttnBuilder
multihead_attention = MultiHeadAttnBuilder().load()
colossal_multihead_attention = multihead_attention
# create the layer in cuda kernels.
cuda_module = colossal_multihead_attention
create_layer_func = (cuda_module.create_multihead_attention_fp16
if self.config.fp16 else cuda_module.create_multihead_attention_fp32)
create_layer_func = (
cuda_module.create_multihead_attention_fp16
if self.config.fp16
else cuda_module.create_multihead_attention_fp32
)
create_layer_func(
self.config.layer_id,
......@@ -204,13 +275,15 @@ class MultiHeadAttention(nn.Module):
with torch.no_grad():
self.in_proj_weight.copy_(
attn_qkvw_global.view(3, hs, hs)[:,
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size), :])
attn_qkvw_global.view(3, hs, hs)[
:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), :
]
)
self.in_proj_bias.copy_(
attn_qkvb_global.view(3, hs)[:,
int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size)])
attn_qkvb_global.view(3, hs)[
:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)
]
)
attn_ow_global = torch.empty(hs, hs)
nn.init.xavier_uniform_(attn_ow_global, 1.0)
......@@ -218,9 +291,9 @@ class MultiHeadAttention(nn.Module):
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
attn_ow_global = attn_ow_global.cpu()
with torch.no_grad():
self.out_proj_weight.copy_(attn_ow_global[:,
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)])
self.out_proj_weight.copy_(
attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)]
)
else:
attn_qkvw = self.in_proj_weight.view(-1, hs)
......@@ -238,7 +311,7 @@ class MultiHeadAttention(nn.Module):
self.config.training = self.training
self.config.is_grad_enabled = torch.is_grad_enabled()
hidden_states = hidden_states.contiguous()
encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous())
encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()
bs, sl, dim = hidden_states.size()
if bs * sl > self.config.max_batch_tokens:
......@@ -250,8 +323,16 @@ class MultiHeadAttention(nn.Module):
else:
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight,
self.in_proj_bias, self.out_proj_weight, self.out_proj_bias,
self.norm_weight, self.norm_bias, self.config)
output = MultiHeadAttention1DFunc.apply(
hidden_states,
encoder_padding_mask,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.norm_weight,
self.norm_bias,
self.config,
)
return output.to(self.precision)
......@@ -108,15 +108,16 @@ class FusedScaleMaskSoftmax(nn.Module):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (self.input_in_fp16
and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time."
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
......@@ -130,13 +131,14 @@ class FusedScaleMaskSoftmax(nn.Module):
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16 or bf16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be in (16, 2048]
and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
):
if (
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16 or bf16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be in (16, 2048]
and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
......
from .option import set_jit_fusion_options
from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .bias_dropout_add import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
from .bias_gelu import bias_gelu_impl
from .option import set_jit_fusion_options
__all__ = [
"bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
"set_jit_fusion_options"
"bias_dropout_add_fused_train",
"bias_dropout_add_fused_inference",
"bias_gelu_impl",
"set_jit_fusion_options",
]
import torch
from torch import Tensor
def bias_dropout_add(x, bias, residual, prob, training):
......@@ -10,16 +9,14 @@ def bias_dropout_add(x, bias, residual, prob, training):
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
def bias_dropout_add_fused_train(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
def bias_dropout_add_fused_inference(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
......@@ -29,7 +29,6 @@ def bias_gelu_back(g, bias, y):
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
......
......@@ -10,15 +10,14 @@ JIT_OPTIONS_SET = False
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options.
"""
"""Set PyTorch JIT layer fusion options."""
# LSG: the latest pytorch and CUDA versions may not support
# the following jit settings
global JIT_OPTIONS_SET
if JIT_OPTIONS_SET == False:
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
......@@ -38,12 +37,14 @@ def set_jit_fusion_options():
JIT_OPTIONS_SET = True
def warmup_jit_fusion(batch_size: int,
hidden_size: int,
seq_length: int = 512,
vocab_size: int = 32768,
dtype: torch.dtype = torch.float32):
""" Compile JIT functions before the main training steps """
def warmup_jit_fusion(
batch_size: int,
hidden_size: int,
seq_length: int = 512,
vocab_size: int = 32768,
dtype: torch.dtype = torch.float32,
):
"""Compile JIT functions before the main training steps"""
embed = Embedding(vocab_size, hidden_size).to(get_current_device())
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
......
try:
import triton
HAS_TRITON = True
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
......@@ -11,8 +12,14 @@ try:
from .token_attention_kernel import token_attention_fwd
__all__ = [
"llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
"copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
"llama_context_attn_fwd",
"bloom_context_attn_fwd",
"softmax",
"layer_norm",
"rmsnorm_forward",
"copy_kv_cache_to_dest",
"rotary_embedding_fwd",
"token_attention_fwd",
]
except ImportError:
......
import torch
import math
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
......@@ -10,28 +13,42 @@ except ImportError:
if HAS_TRITON:
'''
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
'''
"""
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
"""
@triton.jit
def _context_flash_attention_kernel(
Q, K, V, sm_scale,
B_Start_Loc, B_Seqlen,
TMP,
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
TMP,
alibi_ptr,
Out,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_obs, stride_oh, stride_od,
stride_tmp_b, stride_tmp_h, stride_tmp_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_tmp_b,
stride_tmp_h,
stride_tmp_s,
# suggested set-up: 64, 128, 256, 512
BLOCK_M: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
batch_id = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
......@@ -40,13 +57,18 @@ if HAS_TRITON:
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# get batch info
# get batch info
cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
block_start_loc = BLOCK_M * start_m
load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
load_p_ptrs = (
Q
+ (cur_batch_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
......@@ -56,7 +78,7 @@ if HAS_TRITON:
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if alibi_ptr is not None:
alibi_m = tl.load(alibi_ptr + cur_head)
......@@ -64,8 +86,11 @@ if HAS_TRITON:
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
k = tl.load(
k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
......@@ -95,21 +120,25 @@ if HAS_TRITON:
acc_scale = tl.load(t_ptrs)
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
v = tl.load(
v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
off_o = (
(cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
@torch.no_grad()
def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
BLOCK = 128
......@@ -129,17 +158,31 @@ if HAS_TRITON:
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
_context_flash_attention_kernel[grid](
q, k, v, sm_scale,
b_start_loc, b_seq_len,
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
alibi,
o,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
tmp.stride(0), tmp.stride(1), tmp.stride(2),
# manually setting this block num, we can use tuning config to further speed-up
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
# manually setting this block num, we can use tuning config to further speed-up
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
......@@ -147,7 +190,7 @@ if HAS_TRITON:
num_stages=1,
)
return
@torch.no_grad()
def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 128
......@@ -166,19 +209,34 @@ if HAS_TRITON:
num_warps = 4 if Lk <= 64 else 8
# num_warps = 4
_context_flash_attention_kernel[grid](
q, k, v, sm_scale, b_start_loc, b_seq_len,
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
tmp,
None,
o,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
o.stride(0), o.stride(1), o.stride(2),
tmp.stride(0), tmp.stride(1), tmp.stride(2),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
tmp.stride(0),
tmp.stride(1),
tmp.stride(2),
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
\ No newline at end of file
return
......@@ -3,25 +3,28 @@ import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
@triton.jit
def _fwd_copy_kv_cache_dest(
kv_cache_ptr, dest_index_ptr,
kv_cache_ptr,
dest_index_ptr,
out,
stride_k_bs,
stride_k_h,
stride_k_bs,
stride_k_h,
stride_k_d,
stride_o_bs,
stride_o_h,
stride_o_bs,
stride_o_h,
stride_o_d,
head_num,
BLOCK_DMODEL: tl.constexpr,
BLOCK_HEAD: tl.constexpr
BLOCK_HEAD: tl.constexpr,
):
cur_index = tl.program_id(0)
offs_h = tl.arange(0, BLOCK_HEAD)
......@@ -31,15 +34,14 @@ if HAS_TRITON:
cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets
o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
o_ptrs = out + dest_index * stride_o_bs + o_offsets
k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
return
@torch.no_grad()
def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
seq_len = dest_index_ptr.shape[0]
......@@ -47,16 +49,18 @@ if HAS_TRITON:
head_dim = k_ptr.shape[2]
assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2
_fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr, dest_index_ptr, out,
k_ptr.stride(0),
k_ptr.stride(1),
k_ptr,
dest_index_ptr,
out,
k_ptr.stride(0),
k_ptr.stride(1),
k_ptr.stride(2),
out.stride(0),
out.stride(1),
out.stride(0),
out.stride(1),
out.stride(2),
head_num,
BLOCK_DMODEL=head_dim,
......@@ -65,5 +69,3 @@ if HAS_TRITON:
num_stages=2,
)
return
......@@ -3,6 +3,7 @@ import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
......@@ -14,13 +15,13 @@ if HAS_TRITON:
@triton.jit
def _layer_norm_fwd_fused(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
......@@ -32,15 +33,15 @@ if HAS_TRITON:
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.where(cols < N, x - mean, 0.)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
x = tl.where(cols < N, x - mean, 0.0)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
......@@ -50,7 +51,7 @@ if HAS_TRITON:
mask = cols < N
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - mean) * rstd
y = x_hat * w + b
# Write output
......@@ -71,13 +72,7 @@ if HAS_TRITON:
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M,)](x_arg,
y,
weight,
bias,
x_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
_layer_norm_fwd_fused[(M,)](
x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
return y
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
......@@ -9,9 +9,10 @@ except ImportError:
if HAS_TRITON:
'''
"""
this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
'''
"""
@triton.jit
def qkv_gemm_4d_kernel(
a_ptr,
......@@ -34,12 +35,12 @@ if HAS_TRITON:
stride_cn,
scale,
# Meta-parameters
BLOCK_SIZE_M : tl.constexpr = 64,
BLOCK_SIZE_N : tl.constexpr = 32,
BLOCK_SIZE_K : tl.constexpr = 32,
GROUP_SIZE_M : tl.constexpr = 8,
BLOCK_SIZE_M: tl.constexpr = 64,
BLOCK_SIZE_N: tl.constexpr = 32,
BLOCK_SIZE_K: tl.constexpr = 32,
GROUP_SIZE_M: tl.constexpr = 8,
):
r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
r"""A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
where score_matrix is softmax(Q*K^T/sqrt(hidden_size))
Args:
a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K)
......@@ -53,21 +54,21 @@ if HAS_TRITON:
stride_bh(tl.constexpr): stride for h-dimension for tensor array B
stride_bk(tl.constexpr): stride for k-dimension for tensor array B
stride_bn(tl.constexpr): stride for n-dimension for tensor array B
stride_cb(tl.constexpr): stride for bs-dimension for tensor array output
stride_cb(tl.constexpr): stride for bs-dimension for tensor array output
stride_ch(tl.constexpr): stride for h-dimension for tensor array output
stride_cm(tl.constexpr): stride for m-dimension for tensor array output
stride_cn(tl.constexpr): stride for n-dimension for tensor array output
BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a
BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b
BLOCK_SIZE_K : tiling size for K-dimension of a and b
GROUP_SIZE_M : group size for reducing cache miss, more details:
GROUP_SIZE_M : group size for reducing cache miss, more details:
"""
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
batch = tl.program_id(axis = 0)
head = tl.program_id(axis = 1)
pid = tl.program_id(axis = 2)
batch = tl.program_id(axis=0)
head = tl.program_id(axis=1)
pid = tl.program_id(axis=2)
# the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
num_pid_in_group = GROUP_SIZE_M * num_pid_n
......@@ -77,33 +78,38 @@ if HAS_TRITON:
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
(offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
a_ptrs = (
a_ptr + batch * stride_ab + head * stride_ah + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
)
b_ptrs = (
b_ptr + batch * stride_bb + head * stride_bh + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.)
b = tl.load(b_ptrs, mask=b_mask, other=0.)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
accumulator = accumulator.to(c_ptr.dtype.element_ty)
if scale > 0:
accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)
offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
stride_cn * offs_accumu_n[None, :])
c_ptrs = (
c_ptr
+ batch * stride_cb
+ head * stride_ch
+ stride_cm * offs_accumu_m[:, None]
+ stride_cn * offs_accumu_n[None, :]
)
accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
tl.store(c_ptrs, accumulator, mask=accumulator_mask)
......@@ -3,17 +3,19 @@ import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
this kernel function is modified from
https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
'''
"""
this kernel function is modified from
https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
"""
@triton.jit
def _rms_norm_fwd_fused(
X, # pointer to the input
......@@ -32,7 +34,7 @@ if HAS_TRITON:
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
......@@ -41,13 +43,12 @@ if HAS_TRITON:
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
x_hat = x * rstd
y = x_hat * w
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)
def rmsnorm_forward(x, weight, eps):
# allocate output
y = torch.empty_like(x)
......@@ -66,7 +67,5 @@ if HAS_TRITON:
BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
num_warps = 8
# enqueue kernel
_rms_norm_fwd_fused[(M,)](x_arg, y, weight,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
_rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y
......@@ -29,19 +29,29 @@ def _rotary_kernel(
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride
off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride
off_q0 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range0[None, None, :] * q_d_stride
)
off_q1 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range1[None, None, :] * q_d_stride
)
off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
q0 = tl.load(q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0)
q1 = tl.load(q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0)
q0 = tl.load(
q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
q1 = tl.load(
q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
......@@ -49,12 +59,16 @@ def _rotary_kernel(
out0 = q0 * cos - q1 * sin
out1 = q0 * sin + q1 * cos
tl.store(q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
tl.store(q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
tl.store(
q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
tl.store(
q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
return
......
import torch
from torch import nn
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
......@@ -13,9 +12,10 @@ if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
from .softmax import softmax_kernel
def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
input_mask: torch.Tensor, scale: float):
r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
def self_attention_forward_without_fusion(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float
):
r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
......@@ -65,7 +65,7 @@ if HAS_TRITON:
score_output.stride(2),
score_output.stride(3),
scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
......@@ -79,7 +79,6 @@ if HAS_TRITON:
n_rows, n_cols = score_output.shape
if n_rows <= 350000:
block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if block_size >= 4096:
......@@ -142,15 +141,9 @@ if HAS_TRITON:
)
return output.view(batches, -1, d_model)
def self_attention_compute_using_triton(qkv,
input_mask,
layer_past,
alibi,
scale,
head_size,
triangular=False,
use_flash=False):
def self_attention_compute_using_triton(
qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False
):
assert qkv.is_contiguous()
assert alibi is None, "current triton self-attention does not support alibi"
batches = qkv.shape[0]
......@@ -158,8 +151,8 @@ if HAS_TRITON:
num_of_heads = d_model // head_size
q = qkv[:, :, :d_model]
k = qkv[:, :, d_model:d_model * 2]
v = qkv[:, :, d_model * 2:]
k = qkv[:, :, d_model : d_model * 2]
v = qkv[:, :, d_model * 2 :]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
......
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
'''
softmax kernel is modified based on
"""
softmax kernel is modified based on
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
'''
"""
@triton.jit
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
    r"""the kernel function for implementing softmax operator
    Args:
        output_ptr: the output after finishing softmax operation, (N, hidden_dim)
        input_ptr: the tensor of input, shape should be (N, hidden_dim)
        row_stride: the offset between the starts of two consecutive rows of the input
        n_cols: the number of cols of input
        mask_ptr: optional additive mask addressed with the same row stride as the input, or None
        BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
    """
    # one program instance handles one row
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # lanes past n_cols read -inf so they can never win the row maximum
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32)
    row_minus_max = row - tl.max(row, axis=0)

    if mask_ptr is not None:
        # load the mask values of this row (the row index is `row_idx`; the
        # previous spelling `row_indx` referred to a name that is never defined)
        mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
        mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
        # the mask is additive: it is summed onto the max-shifted logits
        row_minus_max = row_minus_max + mask

    numerator = tl.exp(row_minus_max)
......@@ -43,17 +46,16 @@ if HAS_TRITON:
output_ptrs = output_row_start_ptr + col_offsets
# Write back output to DRAM
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
    """Softmax over the last dimension of ``input``, computed by ``softmax_kernel``.

    Args:
        input (torch.Tensor): tensor to normalise; it is viewed as ``(-1, hidden_dim)``
            where ``hidden_dim`` is the size of its last dimension.
        mask (torch.Tensor, optional): mask whose last dimension equals that of ``input``;
            it is flattened the same way and must then have as many rows as ``input``.
        dim (int): axis to normalise over; only the last axis is supported.

    Returns:
        torch.Tensor: a tensor allocated with ``torch.empty_like(input)`` that receives the result.

    Raises:
        AssertionError: if ``dim`` is not the last axis or the mask shape does not match.
    """
    if mask is not None:
        # compare the trailing sizes; `input[-1] == mask[-1]` compared tensor slices,
        # whose truth value is ambiguous for more than one element
        assert input.shape[-1] == mask.shape[-1], "the last dimentions should be the same for input and mask"
    assert dim == -1 or dim == len(input.shape) - 1, "currently softmax layer only support last dimention"

    hidden_dim = input.shape[-1]
    # allocate the result with the caller's shape before flattening the input
    output = torch.empty_like(input)
    input = input.view(-1, hidden_dim)
    if mask is not None:
        mask = mask.view(-1, hidden_dim)
        assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"
......@@ -67,30 +69,31 @@ if HAS_TRITON:
else:
num_warps = 4
if num_rows <= 350000:
if num_rows <= 350000:
grid = (num_rows,)
softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
softmax_kernel[grid](
output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps
)
else:
grid = lambda meta: ()
grid = lambda meta: (
triton.cdiv(num_rows, meta["BLOCK_M"]),
)
grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),)
BLOCK_M = 32
if block_size >= 4096:
BLOCK_M = 4
pass
elif block_size >= 2048:
BLOCK_M = 8
pass
softmax_kernel[grid](output_ptr = output,
input_ptr = input,
row_stride = input.stride(0),
n_rows = num_rows,
n_cols = num_cols,
mask_ptr = mask,
# currently manually setting up size
BLOCK_M = 32,
BLOCK_SIZE = block_size)
softmax_kernel[grid](
output_ptr=output,
input_ptr=input,
row_stride=input.stride(0),
n_rows=num_rows,
n_cols=num_cols,
mask_ptr=mask,
# currently manually setting up size
BLOCK_M=32,
BLOCK_SIZE=block_size,
)
return output
\ No newline at end of file
return output
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import math
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
......@@ -15,10 +15,28 @@ except ImportError:
if HAS_TRITON:
@triton.jit
def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride,
q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride,
attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
def _token_attn_1_kernel(
Q,
K,
sm_scale,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
start_n = tl.program_id(2)
......@@ -40,9 +58,11 @@ if HAS_TRITON:
for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark)
offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0)
k_loc = tl.load(
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1)
......@@ -52,11 +72,29 @@ if HAS_TRITON:
return
@triton.jit
def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen,
max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride,
q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride,
k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr):
def _token_attn_1_alibi_kernel(
Q,
K,
sm_scale,
alibi,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
attn_out,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
q_batch_stride,
q_head_stride,
q_head_dim_stride,
k_batch_stride,
k_head_stride,
k_head_dim_stride,
attn_head_stride,
attn_batch_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
start_n = tl.program_id(2)
......@@ -79,9 +117,11 @@ if HAS_TRITON:
alibi_m = tl.load(alibi + current_head)
q = tl.load(Q + off_q + start_mark)
offs_n_new = current_batch_start_index + offs_n
k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0)
k_loc = tl.load(
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
mask=offs_n_new < current_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
att_value = tl.sum(q[None, :] * k, 1)
......@@ -92,14 +132,9 @@ if HAS_TRITON:
return
@torch.no_grad()
def token_attn_fwd_1(q,
k,
attn_out,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
alibi=None):
def token_attn_fwd_1(
q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None
):
BLOCK = 32
# shape constraints
q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
......@@ -168,9 +203,17 @@ if HAS_TRITON:
return
@triton.jit
def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out,
logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride,
BLOCK_SIZE: tl.constexpr):
def _token_attn_softmax_fwd(
softmax_logics,
kv_cache_start_loc,
kv_cache_seqlen,
softmax_prob_out,
logics_head_dim_stride,
logics_batch_stride,
prob_head_dim_stride,
prob_batch_stride,
BLOCK_SIZE: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
......@@ -178,20 +221,26 @@ if HAS_TRITON:
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
row = tl.load(softmax_logics + current_head * logics_head_dim_stride +
(current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
mask=col_offsets < current_batch_seq_len,
other=-float('inf')).to(tl.float32)
row = tl.load(
softmax_logics
+ current_head * logics_head_dim_stride
+ (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
mask=col_offsets < current_batch_seq_len,
other=-float("inf"),
).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
tl.store(softmax_prob_out + current_head * prob_head_dim_stride +
(current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
softmax_output,
mask=col_offsets < current_batch_seq_len)
tl.store(
softmax_prob_out
+ current_head * prob_head_dim_stride
+ (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
softmax_output,
mask=col_offsets < current_batch_seq_len,
)
return
@torch.no_grad()
......@@ -220,11 +269,27 @@ if HAS_TRITON:
return
@triton.jit
def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride,
v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride,
attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr):
def _token_attn_2_kernel(
Prob,
V,
attn_out,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seqlen,
max_kv_cache_len,
kv_cache_loc_b_stride,
kv_cache_loc_s_stride,
prob_head_dim_stride,
prob_batch_stride,
v_batch_stride,
v_head_stride,
v_head_dim_stride,
attn_out_batch_stride,
attn_out_head_stride,
attn_out_head_dim_stride,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
current_batch = tl.program_id(0)
current_head = tl.program_id(1)
......@@ -232,7 +297,6 @@ if HAS_TRITON:
offs_d = tl.arange(0, HEAD_DIM)
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
current_batch_end_index = current_batch_seq_len
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
......@@ -242,19 +306,29 @@ if HAS_TRITON:
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
for start_n in range(0, current_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0)
v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0)
v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride,
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
other=0.0)
p_value = tl.load(
Prob + p_offs + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0,
)
v_loc = tl.load(
kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
mask=(start_n + offs_n) < current_batch_seq_len,
other=0.0,
)
v_value = tl.load(
V + v_offs + v_loc[:, None] * v_batch_stride,
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
other=0.0,
)
acc += tl.sum(p_value[:, None] * v_value, 0)
acc = acc.to(tl.float16)
off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride
off_o = (
current_batch * attn_out_batch_stride
+ current_head * attn_out_head_stride
+ offs_d * attn_out_head_dim_stride
)
out_ptrs = attn_out + off_o
tl.store(out_ptrs, acc)
return
......@@ -296,15 +370,9 @@ if HAS_TRITON:
return
@torch.no_grad()
def token_attention_fwd(q,
k,
v,
attn_out,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=None):
def token_attention_fwd(
q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None
):
head_num = k.shape[1]
batch_size = kv_cache_seq_len.shape[0]
calcu_shape1 = (batch_size, head_num, k.shape[2])
......@@ -312,21 +380,24 @@ if HAS_TRITON:
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
token_attn_fwd_1(q.view(calcu_shape1),
k,
att_m_tensor,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=alibi)
token_attn_fwd_1(
q.view(calcu_shape1),
k,
att_m_tensor,
kv_cache_loc,
kv_cache_start_loc,
kv_cache_seq_len,
max_len_in_batch,
alibi=alibi,
)
prob = torch.empty_like(att_m_tensor)
token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None
token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
max_len_in_batch)
token_attn_fwd_2(
prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
)
prob = None
......
from .lazy_init import LazyInitContext, LazyTensor
# Public names re-exported from .lazy_init (each listed exactly once).
__all__ = [
    "LazyInitContext",
    "LazyTensor",
]
from contextlib import contextmanager
from types import MethodType
from typing import Callable, Dict, Optional, Union
......@@ -35,43 +34,43 @@ _NO_META_FACTORY = [
"eye",
]
# Ops that cannot be recorded lazily: tensors involved in them are materialized early.
_EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]

# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]

# Legacy constructor name on the `torch` module -> dtype that constructor produces.
_LEGACY_TENSOR_CONSTRUCTOR = {
    "FloatTensor": torch.float,
    "DoubleTensor": torch.double,
    "HalfTensor": torch.half,
    "BFloat16Tensor": torch.bfloat16,
    "ByteTensor": torch.uint8,
    "CharTensor": torch.int8,
    "ShortTensor": torch.short,
    "IntTensor": torch.int,
    "LongTensor": torch.long,
    "BoolTensor": torch.bool,
}

# Shared zero-element tensor.
_EMPTY_DATA = torch.empty(0)
class _MyTensor(Tensor):
"""This class is only for correctness verification.
"""
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
"""This class is only for correctness verification."""
_pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
default_device: Optional[torch.device] = None
def __new__(cls, func, *args, concrete_data=None, **kwargs) -> "_MyTensor":
    """Build the tensor eagerly and wrap it as an instance of ``cls``.

    Args:
        func: factory called as ``func(*args, **kwargs)`` when no ``concrete_data`` is given.
        *args: positional arguments forwarded to ``func``.
        concrete_data: an already-built tensor; when provided, ``func`` is not called.
        **kwargs: keyword arguments forwarded to ``func``; ``device`` is overwritten
            with ``cls.default_device``.

    Returns:
        The data wrapped as ``cls``, carrying over the data's ``requires_grad``.
    """
    # hook invoked before every tensor creation
    cls._pre_op_fn()
    if concrete_data is not None:
        # uniform api as LazyTensor
        data = concrete_data
    else:
        # the class-wide default device always wins over a caller-supplied one
        kwargs["device"] = cls.default_device
        data = func(*args, **kwargs)
    return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
......@@ -82,12 +81,11 @@ class _MyTensor(Tensor):
def _data_tolist(tensor: torch.Tensor) -> list:
    """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.

    Args:
        tensor (torch.Tensor): the tensor (possibly a subclass instance) to convert.

    Returns:
        list: the result of ``tensor.data.tolist()``.
    """
    return tensor.data.tolist()
def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.
The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
......@@ -104,7 +102,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
tensor.__class__ = cls_to_become
if cls_to_become is Parameter:
# to fit UninitializedParameter
delattr(tensor, '_is_param')
delattr(tensor, "_is_param")
tensor.data = target
tensor.requires_grad = target.requires_grad
# subclass of torch.Tensor does not have tolist() method
......@@ -147,8 +145,8 @@ class LazyTensor(torch.Tensor):
"""
_repr = True
_meta_data: Optional[MetaTensor] = None # shape, dtype, device
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
_meta_data: Optional[MetaTensor] = None # shape, dtype, device
_pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
default_device: Optional[torch.device] = None
......@@ -159,8 +157,8 @@ class LazyTensor(torch.Tensor):
elem = concrete_data
else:
if meta_data is None:
device = kwargs.get('device', 'cpu')
elem = func(*args, **{**kwargs, 'device': 'meta'})
device = kwargs.get("device", "cpu")
elem = func(*args, **{**kwargs, "device": "meta"})
meta_data = MetaTensor(elem, device=device)
elem = meta_data._tensor
# As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
......@@ -170,10 +168,10 @@ class LazyTensor(torch.Tensor):
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
    """Record how to build this tensor later instead of building it now.

    Args:
        func: the factory function; stored together with its arguments.
        *args: positional arguments for ``func``.
        meta_data: not used in this method.
        concrete_data: already-materialized data, stored as ``_materialized_data``.
        **kwargs: keyword arguments for ``func``.
    """
    if func.__name__ in _NORMAL_FACTORY:
        # normal factories are pinned to the class-wide default device
        kwargs = {**kwargs, "device": LazyTensor.default_device}
    self._factory_method = (func, args, kwargs)  # (func, args, kwargs)
    self._op_buffer = []  # entries are (func, args, kwargs)
    self._materialized_data: Optional[torch.Tensor] = concrete_data  # materialized data
def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
......@@ -200,12 +198,11 @@ class LazyTensor(torch.Tensor):
return _convert_cls(self, local_tensor)
def clean(self) -> None:
    """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.

    Raises:
        AttributeError: if any of the four attributes is not set on this instance
            (for example when ``clean`` is called a second time).
    """
    # each attribute is removed exactly once; a repeated delattr would raise
    for attr_name in ("_factory_method", "_op_buffer", "_materialized_data", "_meta_data"):
        delattr(self, attr_name)
@staticmethod
def _replace_with_materialized(x):
......@@ -221,8 +218,9 @@ class LazyTensor(torch.Tensor):
# apply cached sequence
self._pre_op_fn()
init_val = func(*tree_map(self._replace_with_materialized, args),
**tree_map(self._replace_with_materialized, kwargs))
init_val = func(
*tree_map(self._replace_with_materialized, args), **tree_map(self._replace_with_materialized, kwargs)
)
self._materialized_data = self._rerun_ops(init_val)
return self._materialized_data
......@@ -243,13 +241,13 @@ class LazyTensor(torch.Tensor):
packed = None
for (func, args, kwargs) in self._op_buffer:
for func, args, kwargs in self._op_buffer:
if func == torch.Tensor.requires_grad_:
packed = func, args, kwargs # requires grad should be set at last
packed = func, args, kwargs # requires grad should be set at last
else:
self._pre_op_fn()
o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value
target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value
# super-dainiu: set requires_grad after all inplace-ops are done
if packed is not None:
......@@ -268,8 +266,11 @@ class LazyTensor(torch.Tensor):
# These OPs cannot be lazy and related tensors should be early materialized
tree_map(cls._replace_with_materialized, args)
tree_map(cls._replace_with_materialized, kwargs)
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
or func.__name__ in ('__setitem__', '__set__'))
is_inplace: bool = (
func.__name__.endswith("_")
and not (func.__name__.endswith("__"))
or func.__name__ in ("__setitem__", "__set__")
)
is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
......@@ -285,11 +286,11 @@ class LazyTensor(torch.Tensor):
target: LazyTensor = args[0].clone()
target._op_buffer.append((func, args, kwargs))
target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]),
**tree_map(unwrap, kwargs))
target._meta_data = getattr(target._meta_data, func.name)(
*tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs)
)
return target
else:
meta_to_lazy = {}
def unwrap(x):
......@@ -328,10 +329,9 @@ class LazyTensor(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Intentionally a no-op: accepts the dispatch arguments and returns ``None``."""
    pass  # skip
def clone(self) -> "LazyTensor":
def factory_fn():
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
......@@ -346,8 +346,10 @@ class LazyTensor(torch.Tensor):
def __deepcopy__(self, memo):
if not self.is_leaf:
raise RuntimeError("Only Tensors created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
raise RuntimeError(
"Only Tensors created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment"
)
if id(self) in memo:
return memo[id(self)]
......@@ -375,7 +377,7 @@ class LazyTensor(torch.Tensor):
return self
@data.setter
def data(self, other: 'LazyTensor'):
def data(self, other: "LazyTensor"):
"""This is slightly different from the original `data` setter.
E.g.:
......@@ -413,7 +415,7 @@ class LazyTensor(torch.Tensor):
def __rpow__(self, other):
    """Reflected power: compute ``other ** self``.

    ``other`` is first turned into a tensor on ``self.device`` whose dtype is
    ``torch.result_type(self, other)``, then raised to the power ``self``.
    """
    dtype = torch.result_type(self, other)
    return torch.tensor(other, dtype=dtype, device=self.device) ** self
class LazyInitContext:
......@@ -444,11 +446,14 @@ class LazyInitContext:
1. Quantization strategies can be applied before allocating real memory.
2. Lazy initialization seems slower than normal initialization.
"""
_replaced: bool = False
def __init__(self,
tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor,
default_device: Optional[Union[torch.device, str, int]] = None):
def __init__(
self,
tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor,
default_device: Optional[Union[torch.device, str, int]] = None,
):
assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self.overrides = {}
self.tensor_cls = tensor_cls
......@@ -457,7 +462,7 @@ class LazyInitContext:
def __enter__(self):
if LazyInitContext._replaced:
raise RuntimeError(f'LazyInitContext is not reentrant')
raise RuntimeError(f"LazyInitContext is not reentrant")
LazyInitContext._replaced = True
self.old_default_device = self.tensor_cls.default_device
self.tensor_cls.default_device = self.default_device
......@@ -485,17 +490,17 @@ class LazyInitContext:
return args[0]
elif len(args) == 1:
# (object data, *, torch.device device)
kwargs = {**kwargs, 'dtype': dtype}
replaced, orig = self.overrides['tensor']
kwargs = {**kwargs, "dtype": dtype}
replaced, orig = self.overrides["tensor"]
return replaced(*args, **kwargs)
elif _is_int_tuple(args):
# (tuple of ints size, *, torch.device device)
kwargs = {**kwargs, 'dtype': dtype}
replaced, orig = self.overrides['empty']
kwargs = {**kwargs, "dtype": dtype}
replaced, orig = self.overrides["empty"]
return replaced(*args, **kwargs)
else:
raise TypeError(
f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)'
f"new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)"
)
return wrapper, target
......@@ -514,23 +519,29 @@ class LazyInitContext:
if callable(getattr(torch, target, None))
}
self.overrides.update({
target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
for target in _NORMAL_FACTORY
if callable(getattr(torch, target + '_like', None))
})
self.overrides.update({
target: wrap_legacy_constructor(getattr(torch, target), dtype)
for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
if callable(getattr(torch, target, None))
})
self.overrides.update({
target: wrap_no_meta_factory(getattr(torch, target))
for target in _NO_META_FACTORY
if callable(getattr(torch, target, None))
})
self.overrides.update(
{
target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
for target in _NORMAL_FACTORY
if callable(getattr(torch, target + "_like", None))
}
)
self.overrides.update(
{
target: wrap_legacy_constructor(getattr(torch, target), dtype)
for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
if callable(getattr(torch, target, None))
}
)
self.overrides.update(
{
target: wrap_no_meta_factory(getattr(torch, target))
for target in _NO_META_FACTORY
if callable(getattr(torch, target, None))
}
)
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, wrapper)
......@@ -556,10 +567,9 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod
def distribute(module: nn.Module,
device_mesh: DeviceMesh,
sharding_spec_dict: Dict[str, ShardingSpec],
verbose: bool = False) -> nn.Module:
def distribute(
module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
) -> nn.Module:
"""Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
......@@ -574,9 +584,9 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose)
def _apply_to_lazy_module(module: nn.Module,
apply_fn: Callable[[str, torch.Tensor], None],
verbose: bool = False) -> nn.Module:
def _apply_to_lazy_module(
module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False
) -> nn.Module:
if verbose:
# verbose info
param_cnt = 0
......@@ -590,7 +600,7 @@ def _apply_to_lazy_module(module: nn.Module,
if verbose:
param_cnt += 1
total_numel += p.numel()
if getattr(p, '_materialized_data', False) is None:
if getattr(p, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1
else:
......@@ -612,10 +622,11 @@ def _apply_to_lazy_module(module: nn.Module,
if verbose:
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
_print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
_print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
_print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}")
_print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}")
_print_rank_0(
f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')
f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%"
)
return module
......
from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
# Public names re-exported from .initialize (each listed exactly once).
__all__ = [
    "launch",
    "launch_from_openmpi",
    "launch_from_slurm",
    "launch_from_torch",
    "initialize",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment