Commit 85762c1a authored by Xiaowei.zhang's avatar Xiaowei.zhang
Browse files

Init the main branch for aiter

parent ae0b3521
Pipeline #3505 canceled with stages
# SPDX-License-Identifier: MIT
from torch import Tensor, Generator
from typing import Optional, Tuple
from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR
from ..utility import dtypes
import torch
@compile_ops("module_mha_fwd", fc_name="mha_fwd")
def mha_fwd(
q: Tensor,
k: Tensor,
v: Tensor,
dropout_p: float,
softmax_scale: float,
is_causal: bool,
window_size_left: int,
window_size_right: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
gen: Optional[Generator] = None,
): ...
@compile_ops("module_fmha_v3_fwd", fc_name="fmha_v3_fwd")
def fmha_v3_fwd(
q: Tensor,
k: Tensor,
v: Tensor,
dropout_p: float,
softmax_scale: float,
is_causal: bool,
window_size_left: int,
window_size_right: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
gen: Optional[Generator] = None,
): ...
@compile_ops("module_mha_varlen_fwd", fc_name="mha_varlen_fwd")
def mha_varlen_fwd(
q: Tensor,
k: Tensor,
v: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
softmax_scale: float,
logits_soft_cap: float,
zero_tensors: bool,
is_causal: bool,
window_size_left: int,
window_size_right: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[Tensor] = None,
block_table: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
gen: Optional[Generator] = None,
) -> list[Tensor]: ...
@compile_ops("module_mha_bwd", fc_name="mha_bwd")
def mha_bwd(
dout: Tensor,
q: Tensor,
k: Tensor,
v: Tensor,
out: Tensor,
softmax_lse: Tensor,
dropout_p: float,
softmax_scale: float,
is_causal: bool,
window_size_left: int,
window_size_right: int,
deterministic: bool,
dq: Optional[Tensor] = None,
dk: Optional[Tensor] = None,
dv: Optional[Tensor] = None,
dbias: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
): ...
@compile_ops("module_fmha_v3_bwd", fc_name="fmha_v3_bwd")
def fmha_v3_bwd(
dout: Tensor,
q: Tensor,
k: Tensor,
v: Tensor,
out: Tensor,
softmax_lse: Tensor,
dropout_p: float,
softmax_scale: float,
is_causal: bool,
window_size_left: int,
window_size_right: int,
deterministic: bool,
is_v3_atomic_fp32: bool,
how_v3_bf16_cvt: int,
dq: Optional[Tensor] = None,
dk: Optional[Tensor] = None,
dv: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
): ...
@compile_ops("module_mha_varlen_bwd", fc_name="mha_varlen_bwd")
def mha_varlen_bwd(
dout: Tensor,
q: Tensor,
k: Tensor,
v: Tensor,
out: Tensor,
softmax_lse: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
softmax_scale: float,
zero_tensors: bool,
is_causal: bool,
window_size_left: int,
window_size_right: int,
deterministic: bool,
dq: Optional[Tensor] = None,
dk: Optional[Tensor] = None,
dv: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
custom_build_args: Optional[dict] = None,
): ...
@compile_ops("module_fmha_v3_varlen_bwd", fc_name="fmha_v3_varlen_bwd")
def fmha_v3_varlen_bwd(
dout: Tensor,
q: Tensor,
k: Tensor,
v: Tensor,
out: Tensor,
softmax_lse: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
softmax_scale: float,
zero_tensors: bool,
is_causal: bool,
window_size_left: int,
window_size_right: int,
deterministic: bool,
is_v3_atomic_fp32: bool,
how_v3_bf16_cvt: int,
dq: Optional[Tensor] = None,
dk: Optional[Tensor] = None,
dv: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
): ...
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def _flash_attn_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
bias: Optional[torch.Tensor],
alibi_slopes: Optional[torch.Tensor],
return_lse: bool,
return_softmax: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
(_, seqlen_q, _, _) = q.shape
# causal=true is the same as causal=false in this case
if seqlen_q == 1 and alibi_slopes is None:
causal = False
md_name = "mha_fwd"
filter = "*"
if q.dtype == dtypes.fp16:
md_name += "_fp16"
filter += "fp16*"
elif q.dtype == dtypes.bf16:
md_name += "_bf16"
filter += "bf16*"
if bias is not None:
md_name += "_bias"
filter += "_bias*"
elif alibi_slopes is not None:
md_name += "_alibi"
filter += "_alibi*"
else:
md_name += "_nbias"
filter += "_nbias*"
if not causal and window_size_left == -1 and window_size_right == -1:
md_name += "_nmask"
filter += "_nmask*"
else:
md_name += "_mask"
filter += "_mask*"
if return_lse:
md_name += "_lse"
filter += "_lse*"
else:
md_name += "_nlse"
filter += "_nlse*"
if dropout_p == 0:
md_name += "_ndropout"
filter += "_ndropout*"
else:
md_name += "_dropout"
filter += "_dropout*"
blob_gen_cmd = [
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd "
"--receipt 100 --filter {} --output_dir {{}}".format(filter),
f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 2 --output_dir {{}}",
]
(_, seqlen_q, nhead_q, hdim_q) = q.shape
(_, seqlen_k, nhead_k, hdim_v) = v.shape
# mask
window_size_left = -1 if window_size_left >= seqlen_k else window_size_left
window_size_right = -1 if window_size_right >= seqlen_k else window_size_right
mask = causal and window_size_left == -1 # causal mask
nmask = not causal and window_size_left == -1 and window_size_right == -1 # no mask
def can_impl_fmha_v3_fwd():
# basic
ret = alibi_slopes is None
ret &= bias is None
ret &= dropout_p == 0.0
ret &= seqlen_q == seqlen_k
ret &= seqlen_q % 256 == 0
ret &= hdim_q == hdim_v
ret &= hdim_q == 128
ret &= nhead_q % nhead_k == 0
ret &= mask or nmask
ret &= return_lse
ret &= "gfx946" in torch.cuda.get_device_properties("cuda").gcnArchName
return ret
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if can_impl_fmha_v3_fwd():
out, softmax_lse, S_dmask, rng_state = fmha_v3_fwd(
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
return_lse,
return_softmax,
None,
bias,
alibi_slopes,
None,
)
else:
out, softmax_lse, S_dmask, rng_state = mha_fwd(
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
return_lse,
return_softmax,
None,
bias,
alibi_slopes,
None,
custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return out, softmax_lse, S_dmask, rng_state
def _flash_attn_backward(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
dq: Optional[torch.Tensor],
dk: Optional[torch.Tensor],
dv: Optional[torch.Tensor],
dbias: Optional[torch.Tensor],
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
bias: Optional[torch.Tensor],
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
) -> torch.Tensor:
md_name = "mha_bwd"
filter1 = "*" # get_bwd_dot_do_o_blobs()
filter2 = "*" # get_bwd_convert_dq_blobs()
filter3 = "*" # get_bwd_dq_dk_dv_blobs()
if q.dtype == dtypes.fp16:
md_name += "_fp16"
filter1 += "fp16*"
filter2 += "fp16*"
filter3 += "fp16*"
elif q.dtype == dtypes.bf16:
md_name += "_bf16"
filter1 += "bf16*"
filter2 += "bf16*"
filter3 += "bf16*"
if bias is not None:
md_name += "_bias"
filter3 += "_bias*"
elif alibi_slopes is not None:
md_name += "_alibi"
filter3 += "_alibi*"
else:
md_name += "_nbias"
filter3 += "_nbias*"
if dbias is not None:
md_name += "_dbias"
filter3 += "_dbias*"
else:
md_name += "_ndbias"
filter3 += "_ndbias*"
if not causal and window_size_left == -1 and window_size_right == -1:
md_name += "_nmask"
filter3 += "_nmask*"
else:
md_name += "_mask"
filter3 += "_mask*"
if dropout_p == 0:
md_name += "_ndropout"
filter3 += "_ndropout*"
else:
md_name += "_dropout"
filter3 += "_dropout*"
if deterministic:
md_name += "_deterministic"
filter2 += "_deterministic*"
filter3 += "_deterministic*"
else:
md_name += "_ndeterministic"
filter2 += "_ndeterministic*"
filter3 += "_ndeterministic*"
filter = f"{filter1}@{filter2}@{filter3}"
blob_gen_cmd = [
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd "
"--receipt 300 --filter {} --output_dir {{}}".format(filter),
f"{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 1 --output_dir {{}}",
]
(_, seqlen_q, nhead_q, hdim_q) = q.shape
(_, seqlen_k, nhead_k, hdim_v) = v.shape
batch_stride_q = q.stride(0)
stride_q = q.stride(1)
nhead_stride_q = q.stride(2)
batch_stride_k = k.stride(0)
stride_k = k.stride(1)
nhead_stride_k = k.stride(2)
batch_stride_v = v.stride(0)
stride_v = v.stride(1)
nhead_stride_v = v.stride(2)
batch_stride_do = dout.stride(0)
stride_do = dout.stride(1)
nhead_stride_do = dout.stride(2)
batch_stride_dk = dk.stride(0)
nhead_stride_dk = dk.stride(2)
batch_stride_dv = dv.stride(0)
nhead_stride_dv = dv.stride(2)
# mask
window_size_left = -1 if window_size_left >= seqlen_k else window_size_left
window_size_right = -1 if window_size_right >= seqlen_k else window_size_right
mask = causal and window_size_left == -1 # causal mask
nmask = not causal and window_size_left == -1 and window_size_right == -1 # no mask
swa = not causal and (window_size_left > 0 or window_size_right > 0)
def np():
# bwd_hd128_bf16_a16_rtne
# bwd_hd128_bf16_a16_rtna
# bwd_hd128_bf16_a16_rtz
# bwd_hd128_bf16_a32_rtne
# bwd_hd128_bf16_a32_rtna
# bwd_hd128_bf16_a32_rtz
# bwd_hd128_bf16_causal_a16_rtne
# bwd_hd128_bf16_causal_a16_rtna
# bwd_hd128_bf16_causal_a16_rtz
# bwd_hd128_bf16_causal_a32_rtne
# bwd_hd128_bf16_causal_a32_rtna
# bwd_hd128_bf16_causal_a32_rtz
# bwd_hd128_fp16_a16
# bwd_hd128_fp16_a32
# bwd_hd128_fp16_causal_a16
# bwd_hd128_fp16_causal_a32
# bwd_hd64_bf16_a16_rtne
# bwd_hd64_bf16_a16_rtna
# bwd_hd64_bf16_a16_rtz
# bwd_hd64_bf16_causal_a16_rtne
# bwd_hd64_bf16_causal_a16_rtna
# bwd_hd64_bf16_causal_a16_rtz
# bwd_hd64_fp16_a16
# bwd_hd64_fp16_causal_a16
npssk = seqlen_q == seqlen_k
npssk &= seqlen_k % 64 == 0
npssk &= stride_q == stride_do
npssk &= nhead_stride_q == nhead_stride_do
npssk &= batch_stride_q == batch_stride_do
npssk &= stride_k == stride_v
npssk &= nhead_stride_k == nhead_stride_v
npssk &= batch_stride_k == batch_stride_v
npssk &= nhead_stride_k == nhead_stride_dk
npssk &= nhead_stride_v == nhead_stride_dv
npssk &= (batch_stride_dk / batch_stride_k) == (nhead_q / nhead_k)
npssk &= (batch_stride_dv / batch_stride_v) == (nhead_q / nhead_k)
hd128_case = (hdim_q == 128) and npssk
hd64_case = (hdim_q == 64 and is_v3_atomic_fp32 == False) and npssk
ret = hd128_case or hd64_case
return ret
def pssk():
# only for hd64 a32 causal/no causal, fp16/bf16-rtne/rtna/rtz cases
# FIXME: Currently we only support mask_type == mask_enum::no_mask or causal mask with seqlen_q == seqlen_k
# Because python side only support mask_enum::bottom_right
# However v3 kernel only support mask_enum::top_left
# bwd_hd64_bf16_a32_rtne_pssk
# bwd_hd64_bf16_a32_rtna_pssk
# bwd_hd64_bf16_a32_rtz_pssk
# bwd_hd64_bf16_causal_a32_rtne_pssk
# bwd_hd64_bf16_causal_a32_rtna_pssk
# bwd_hd64_bf16_causal_a32_rtz_pssk
# bwd_hd64_fp16_a32_pssk
# bwd_hd64_fp16_causal_a32_pssk
ret = (
is_v3_atomic_fp32 == True
) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret &= hdim_q == 64
ret &= nmask or (
mask and seqlen_q == seqlen_k
) # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
return ret
def pddv():
# only for a16 causal/no causal, fp16/bf16-rtne/rtna/rtz cases
# bwd_hd128_bf16_a16_rtne_pddv
# bwd_hd128_bf16_a16_rtna_pddv
# bwd_hd128_bf16_a16_rtz_pddv
# bwd_hd128_bf16_causal_a16_rtne_pddv
# bwd_hd128_bf16_causal_a16_rtna_pddv
# bwd_hd128_bf16_causal_a16_rtz_pddv
# bwd_hd128_fp16_a16_pddv
# bwd_hd128_fp16_causal_a16_pddv
ret = is_v3_atomic_fp32 == False
ret &= hdim_q > 64 and hdim_q < 128
ret &= seqlen_q == seqlen_k
ret &= seqlen_k % 64 == 0
ret &= stride_q == stride_do
ret &= nhead_stride_q == nhead_stride_do
ret &= batch_stride_q == batch_stride_do
ret &= stride_k == stride_v
ret &= nhead_stride_k == nhead_stride_v
ret &= batch_stride_k == batch_stride_v
ret &= nhead_stride_k == nhead_stride_dk
ret &= nhead_stride_v == nhead_stride_dv
ret &= (batch_stride_dk / batch_stride_k) == (nhead_q / nhead_k)
ret &= (batch_stride_dv / batch_stride_v) == (nhead_q / nhead_k)
return ret
def psskddv():
# only for a32 causal/no causal, fp16/bf16-rtne/rtna/rtz cases
# bwd_hd128_bf16_a32_rtne_psskddv
# bwd_hd128_bf16_a32_rtna_psskddv
# bwd_hd128_bf16_a32_rtz_psskddv
# bwd_hd128_bf16_causal_a32_rtne_psskddv
# bwd_hd128_bf16_causal_a32_rtna_psskddv
# bwd_hd128_bf16_causal_a32_rtz_psskddv
# bwd_hd128_fp16_a32_psskddv
# bwd_hd128_fp16_causal_a32_psskddv
# bwd_hd192_fp16_a32_psskddv
# bwd_hd192_fp16_causal_a32_psskddv
# bwd_hd192_bf16_a32_rtne_psskddv
# bwd_hd192_bf16_a32_rtna_psskddv
# bwd_hd192_bf16_a32_rtz_psskddv
# bwd_hd192_bf16_causal_a32_rtne_psskddv
# bwd_hd192_bf16_causal_a32_rtna_psskddv
# bwd_hd192_bf16_causal_a32_rtz_psskddv
ret = is_v3_atomic_fp32 == True
ret &= hdim_q > 64 and hdim_q <= 192
ret &= (
nmask
or (mask and seqlen_q == seqlen_k)
or (swa and hdim_q > 64 and hdim_q <= 128)
) # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
return ret
def can_impl_fmha_v3_bwd():
# basic
ret = alibi_slopes is None
ret &= bias is None
ret &= dbias is None
ret &= dropout_p == 0.0
ret &= deterministic == False
ret &= hdim_q == hdim_v
ret &= nhead_q % nhead_k == 0
ret &= hdim_q >= 64 and hdim_q <= 192 and hdim_q % 8 == 0
ret &= mask or nmask or swa
ret &= np() or pssk() or pddv() or psskddv()
return ret
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
if can_impl_fmha_v3_bwd():
(
dq,
dk,
dv,
softmax_d,
) = fmha_v3_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
deterministic,
is_v3_atomic_fp32,
how_v3_bf16_cvt,
dq,
dk,
dv,
alibi_slopes,
rng_state,
None,
)
else:
(
dq,
dk,
dv,
softmax_d,
) = mha_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
deterministic,
dq,
dk,
dv,
dbias,
bias,
alibi_slopes,
rng_state,
None,
custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return softmax_d
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
bias,
alibi_slopes,
deterministic,
return_lse,
return_softmax,
is_grad_enabled,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
):
is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
head_size_q_og = q.size(3)
head_size_v_og = v.size(3)
if head_size_q_og % 8 != 0:
q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
if head_size_v_og % 8 != 0:
v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8])
out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
q,
k,
v,
dropout_p,
softmax_scale,
causal=causal,
window_size_left=int(window_size[0]),
window_size_right=int(window_size[1]),
bias=bias,
alibi_slopes=alibi_slopes,
return_lse=return_lse,
return_softmax=return_softmax and dropout_p > 0,
)
if is_grad:
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.bias = bias
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
ctx.head_size_q_og = head_size_q_og
ctx.is_v3_atomic_fp32 = is_v3_atomic_fp32
ctx.how_v3_bf16_cvt = how_v3_bf16_cvt
out = out_padded[..., :head_size_v_og]
result = [out]
if return_lse:
result.append(softmax_lse)
if return_softmax:
result.append(S_dmask)
return result[0] if len(result) == 1 else tuple(result)
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v)
bias = ctx.bias
dbias = torch.empty_like(bias) if bias is not None else None
head_size_q_og = ctx.head_size_q_og
head_size_v_og = dout.size(3)
dout_padded = dout
if head_size_v_og % 8 != 0:
dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_v_og % 8])
_flash_attn_backward(
dout_padded,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
dbias,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
int(ctx.window_size[0]),
int(ctx.window_size[1]),
ctx.bias,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
ctx.is_v3_atomic_fp32,
ctx.how_v3_bf16_cvt,
)
dq = dq[..., :head_size_q_og] # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
dv = dv[..., :head_size_v_og]
return dq, dk, dv, None, None, None, None, dbias, None, None, None, None, None
def flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
bias=None,
alibi_slopes=None,
deterministic=True,
return_lse=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim_q)
k: (batch_size, seqlen, nheads_k, headdim_q)
v: (batch_size, seqlen, nheads_k, headdim_v)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim_q).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
bias: (seqlen_q, seqlen_k)
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim_v).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
bias,
alibi_slopes,
deterministic,
return_lse,
return_attn_probs,
torch.is_grad_enabled(),
)
def _flash_attn_varlen_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
softmax_scale: float,
causal: bool,
logits_soft_cap: float = 0.0,
window_size_left: int = -1,
window_size_right: int = -1,
bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
return_lse: bool = False,
return_softmax: bool = False,
block_table: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
zero_tensors: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# causal=true is the same as causal=false in this case
if max_seqlen_q == 1 and alibi_slopes is None:
causal = False
md_name = "mha_varlen_fwd"
if block_table is None:
filter_fwd = "*" # get_fwd_blobs()
if q.dtype == dtypes.fp16:
md_name += "_fp16"
filter_fwd += "fp16*"
elif q.dtype == dtypes.bf16:
md_name += "_bf16"
filter_fwd += "bf16*"
if 0.0 < logits_soft_cap:
md_name += "_logits"
filter_fwd += "_logits*"
else:
md_name += "_nlogits"
filter_fwd += "_nlogits*"
if bias is not None:
md_name += "_bias"
filter_fwd += "_bias*"
elif alibi_slopes is not None:
md_name += "_alibi"
filter_fwd += "_alibi*"
else:
md_name += "_nbias"
filter_fwd += "_nbias*"
if not causal and window_size_left == -1 and window_size_right == -1:
md_name += "_nmask"
filter_fwd += "_nmask*"
else:
md_name += "_mask"
filter_fwd += "_mask*"
if return_lse:
md_name += "_lse"
filter_fwd += "_lse*"
else:
md_name += "_nlse"
filter_fwd += "_nlse*"
if dropout_p == 0:
md_name += "_ndropout"
filter_fwd += "_ndropout*"
else:
md_name += "_dropout"
filter_fwd += "_dropout*"
blob_gen_cmd = [
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd "
"--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd)
]
blob_gen_cmd.append(
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv "
"--receipt 200 --filter {} --output_dir {{}}".format('" @ "')
)
blob_gen_cmd.append(
f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 3 --output_dir {{}}"
)
else:
filter_fwd_splitkv1 = "*" # get_fwd_splitkv_combine_blobs()
filter_fwd_splitkv2 = "*" # get_fwd_splitkv_blobs()
if q.dtype == dtypes.fp16:
md_name += "_fp16"
filter_fwd_splitkv1 += "fp16*"
filter_fwd_splitkv2 += "fp16*"
elif q.dtype == dtypes.bf16:
md_name += "_bf16"
filter_fwd_splitkv1 += "bf16*"
filter_fwd_splitkv2 += "bf16*"
if 0.0 < logits_soft_cap:
md_name += "_logits"
filter_fwd += "_logits*"
else:
md_name += "_nlogits"
filter_fwd += "_nlogits*"
if bias is not None:
md_name += "_bias"
filter_fwd_splitkv2 += "_bias*"
elif alibi_slopes is not None:
md_name += "_alibi"
filter_fwd_splitkv2 += "_alibi*"
else:
md_name += "_nbias"
filter_fwd_splitkv2 += "_nbias*"
if not causal and window_size_left == -1 and window_size_right == -1:
md_name += "_nmask"
filter_fwd_splitkv2 += "_nmask*"
else:
md_name += "_mask"
filter_fwd_splitkv2 += "_mask*"
if return_lse:
md_name += "_lse"
filter_fwd_splitkv1 += "_lse*"
filter_fwd_splitkv2 += "_lse*"
else:
md_name += "_nlse"
filter_fwd_splitkv1 += "_nlse*"
filter_fwd_splitkv2 += "_nlse*"
md_name += "_pagedkv"
filter_fwd_splitkv2 += "_pagedkv*"
filter_fwd_splitkv = f"{filter_fwd_splitkv1}@{filter_fwd_splitkv2}"
blob_gen_cmd = [
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd "
"--receipt 200 --filter {} --output_dir {{}}".format('" "')
]
blob_gen_cmd.append(
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv "
"--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd_splitkv)
)
blob_gen_cmd.append(
f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 3 --output_dir {{}}"
)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = mha_varlen_fwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
logits_soft_cap,
zero_tensors,
causal,
window_size_left,
window_size_right,
return_lse,
return_softmax,
out,
block_table,
bias,
alibi_slopes,
None,
custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return out, softmax_lse, S_dmask, rng_state
def _flash_attn_varlen_backward(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
dq: Optional[torch.Tensor],
dk: Optional[torch.Tensor],
dv: Optional[torch.Tensor],
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
alibi_slopes: Optional[torch.Tensor],
deterministic: bool,
rng_state: Optional[torch.Tensor] = None,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
zero_tensors: bool = False,
) -> torch.Tensor:
md_name = "mha_varlen_bwd"
filter1 = "*" # get_bwd_dot_do_o_blobs()
filter2 = "*" # get_bwd_convert_dq_blobs()
filter3 = "*" # get_bwd_dq_dk_dv_blobs()
if q.dtype == dtypes.fp16:
md_name += "_fp16"
filter1 += "fp16*"
filter2 += "fp16*"
filter3 += "fp16*"
elif q.dtype == dtypes.bf16:
md_name += "_bf16"
filter1 += "bf16*"
filter2 += "bf16*"
filter3 += "bf16*"
if alibi_slopes is None:
md_name += "_nbias"
filter3 += "_nbias*"
else:
md_name += "_alibi"
filter3 += "_alibi*"
if not causal and window_size_left == -1 and window_size_right == -1:
md_name += "_nmask"
filter3 += "_nmask*"
else:
md_name += "_mask"
filter3 += "_mask*"
if dropout_p == 0:
md_name += "_ndropout"
filter3 += "_ndropout*"
else:
md_name += "_dropout"
filter3 += "_dropout*"
if deterministic:
md_name += "_deterministic"
filter2 += "_deterministic*"
filter3 += "_deterministic*"
else:
md_name += "_ndeterministic"
filter2 += "_ndeterministic*"
filter3 += "_ndeterministic*"
filter = f"{filter1}@{filter2}@{filter3}"
blob_gen_cmd = [
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d bwd "
"--receipt 400 --filter {} --output_dir {{}}".format(filter),
f"{AITER_CSRC_DIR}/cpp_itfs/mha_bwd_generate.py --receipt 1 --output_dir {{}}",
]
(_, nhead_q, hdim_q) = q.shape
nhead_k = v.shape[-2]
hdim_v = v.shape[-1]
# mask
window_size_left = -1 if window_size_left >= max_seqlen_k else window_size_left
window_size_right = -1 if window_size_right >= max_seqlen_k else window_size_right
mask = causal == True and window_size_left == -1 # causal mask
nmask = (
causal == False and window_size_left == -1 and window_size_right == -1
) # no mask
def pssk():
# only for hd64 a32 causal/no causal, fp16/bf16-rtne/rtna/rtz cases
# FIXME: Currently we only support mask_type == mask_enum::no_mask
# Because python side only support mask_enum::bottom_right
# However v3 kernel only support mask_enum::top_left
# bwd_hd64_bf16_a32_rtne_pssk_group
# bwd_hd64_bf16_a32_rtna_pssk_group
# bwd_hd64_bf16_a32_rtz_pssk_group
# bwd_hd64_bf16_causal_a32_rtne_pssk_group
# bwd_hd64_bf16_causal_a32_rtna_pssk_group
# bwd_hd64_bf16_causal_a32_rtz_pssk_group
# bwd_hd64_fp16_a32_pssk_group
# bwd_hd64_fp16_causal_a32_pssk_group
# bwd_hd128_bf16_a32_rtne_pssk_group
# bwd_hd128_bf16_a32_rtna_pssk_group
# bwd_hd128_bf16_a32_rtz_pssk_group
# bwd_hd128_bf16_causal_a32_rtne_pssk_group
# bwd_hd128_bf16_causal_a32_rtna_pssk_group
# bwd_hd128_bf16_causal_a32_rtz_pssk_group
# bwd_hd128_fp16_a32_pssk_group
# bwd_hd128_fp16_causal_a32_pssk_group
ret = (
is_v3_atomic_fp32 == True
) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret &= hdim_q == 64 or hdim_q == 128
ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
return ret
def psskddv():
# bwd_hd128_bf16_a32_rtne_psskddv_group
# bwd_hd128_bf16_a32_rtna_psskddv_group
# bwd_hd128_bf16_a32_rtz_psskddv_group
# bwd_hd128_bf16_causal_a32_rtne_psskddv_group
# bwd_hd128_bf16_causal_a32_rtna_psskddv_group
# bwd_hd128_bf16_causal_a32_rtz_psskddv_group
# bwd_hd128_fp16_a32_psskddv_group
# bwd_hd128_fp16_causal_a32_psskddv_group
ret = (
is_v3_atomic_fp32 == True
) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
ret &= hdim_q > 64 and hdim_q < 128
ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
return ret
def can_impl_fmha_v3_bwd():
# basic
ret = alibi_slopes is None
# ret &= bias is None
# ret &= dbias is None
ret &= dropout_p == 0.0
ret &= deterministic == False
ret &= hdim_q == hdim_v
ret &= nhead_q % nhead_k == 0
ret &= hdim_q >= 64 and hdim_q <= 128 and hdim_q % 8 == 0
ret &= mask or nmask
ret &= pssk() or psskddv()
return ret
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
if can_impl_fmha_v3_bwd():
(
dq,
dk,
dv,
softmax_d,
) = fmha_v3_varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors,
causal,
window_size_left,
window_size_right,
deterministic,
is_v3_atomic_fp32,
how_v3_bf16_cvt,
dq,
dk,
dv,
alibi_slopes,
rng_state,
None,
)
else:
(
dq,
dk,
dv,
softmax_d,
) = mha_varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
zero_tensors,
causal,
window_size_left,
window_size_right,
deterministic,
dq,
dk,
dv,
alibi_slopes,
rng_state,
None,
custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return softmax_d
class FlashAttnVarlenFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
logits_soft_cap,
causal,
window_size,
bias,
alibi_slopes,
deterministic,
return_lse,
return_softmax,
block_table,
out,
is_grad_enabled,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
):
is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
head_size_q_og = q.size(-1)
head_size_v_og = v.size(-1)
if head_size_q_og % 8 != 0:
q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
if head_size_v_og % 8 != 0:
v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8])
out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
logits_soft_cap=logits_soft_cap,
window_size_left=window_size[0],
window_size_right=window_size[1],
bias=bias,
alibi_slopes=alibi_slopes,
return_lse=return_lse,
return_softmax=return_softmax and dropout_p > 0,
block_table=block_table,
out=out,
)
if is_grad:
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.bias = bias
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
ctx.head_size_q_og = head_size_q_og
ctx.is_v3_atomic_fp32 = is_v3_atomic_fp32
ctx.how_v3_bf16_cvt = how_v3_bf16_cvt
out = out_padded[..., :head_size_v_og]
result = [out]
if return_lse:
result.append(softmax_lse)
if return_softmax:
result.append(S_dmask)
return result[0] if len(result) == 1 else tuple(result)
@staticmethod
def backward(ctx, dout, *args):
(
q,
k,
v,
out,
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
rng_state,
) = ctx.saved_tensors
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
bias = ctx.bias
dbias = torch.empty_like(bias) if bias is not None else None
head_size_q_og = ctx.head_size_q_og
head_size_v_og = dout.size(2)
dout_padded = dout
if head_size_v_og % 8 != 0:
dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_v_og % 8])
# TODO - dbias
_flash_attn_varlen_backward(
dout_padded,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
is_v3_atomic_fp32=ctx.is_v3_atomic_fp32,
how_v3_bf16_cvt=ctx.how_v3_bf16_cvt,
)
dq = dq[..., :head_size_q_og] # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
dv = dv[..., :head_size_v_og]
return (
dq,
dk,
dv,
None,
None,
None,
None,
None,
None,
None,
None,
None,
dbias,
None,
None,
None,
None,
None,
None,
None,
)
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
logits_soft_cap=0.0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
bias=None,
alibi_slopes=None,
deterministic=False,
return_lse=False,
return_attn_probs=False,
block_table=None,
out=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim_q), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim_q), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim_v), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype dtypes.i32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype dtypes.i32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim_q).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
bias: (seqlen_q, seqlen_k)
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim_v).
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenFunc.apply(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
logits_soft_cap,
causal,
window_size,
bias,
alibi_slopes,
deterministic,
return_lse,
return_attn_probs,
block_table,
out,
torch.is_grad_enabled(),
)
@compile_ops("module_mha_batch_prefill", fc_name="mha_batch_prefill")
def mha_batch_prefill(
q: Tensor,
k: Tensor,
v: Tensor,
cu_seqlens_q: Tensor,
kv_indptr: Tensor,
kv_page_indices: Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
softmax_scale: float,
logits_soft_cap: float,
zero_tensors: bool,
is_causal: bool,
window_size_left: int,
window_size_right: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
gen: Optional[Generator] = None,
): ...
def _mha_batch_prefill(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
kv_indptr: torch.Tensor,
kv_page_indices: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
softmax_scale: float,
causal: bool,
logits_soft_cap: float = 0.0,
window_size_left: int = -1,
window_size_right: int = -1,
alibi_slopes: Optional[torch.Tensor] = None,
return_lse: bool = False,
return_softmax: bool = False,
zero_tensors: bool = False,
out: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# causal=true is the same as causal=false in this case
if max_seqlen_q == 1 and alibi_slopes is None:
causal = False
md_name = "mha_batch_prefill"
filter_fwd = "*" # get_fwd_blobs()
if q.dtype == torch.float16:
md_name += "_fp16"
filter_fwd += "fp16*"
elif q.dtype == torch.bfloat16:
md_name += "_bf16"
filter_fwd += "bf16*"
if 0.0 < logits_soft_cap:
md_name += "_logits"
filter_fwd += "_logits*"
else:
md_name += "_nlogits"
filter_fwd += "_nlogits*"
if alibi_slopes is None:
md_name += "_nbias"
filter_fwd += "_nbias*"
else:
md_name += "_alibi"
filter_fwd += "_alibi*"
if not causal and window_size_left == -1 and window_size_right == -1:
md_name += "_nmask"
filter_fwd += "_nmask*"
else:
md_name += "_mask"
filter_fwd += "_mask*"
if return_lse:
md_name += "_lse"
filter_fwd += "_lse*"
else:
md_name += "_nlse"
filter_fwd += "_nlse*"
if dropout_p == 0:
md_name += "_ndropout"
filter_fwd += "_ndropout*"
else:
md_name += "_dropout"
filter_fwd += "_dropout*"
blob_gen_cmd = [
f"{CK_DIR}/example/ck_tile/01_fmha/generate.py -d batch_prefill "
"--receipt 200 --filter {} --output_dir {{}}".format(filter_fwd)
]
blob_gen_cmd.append(
f"{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_generate.py --receipt 4 --output_dir {{}}"
)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = mha_batch_prefill(
q,
k,
v,
cu_seqlens_q,
kv_indptr,
kv_page_indices,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
logits_soft_cap,
zero_tensors,
causal,
window_size_left,
window_size_right,
return_lse,
return_softmax,
out,
alibi_slopes,
None,
custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return out, softmax_lse, S_dmask, rng_state
def mha_batch_prefill_func(
q,
k,
v,
cu_seqlens_q,
kv_indptr,
kv_page_indices,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
logits_soft_cap=0.0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_lse=False,
return_attn_probs=False,
out=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
head_size_q_og = q.size(2)
head_size_v_og = v.size(2)
if head_size_q_og % 8 != 0:
q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
if head_size_v_og % 8 != 0:
v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8])
out_padded, softmax_lse, S_dmask, rng_state = _mha_batch_prefill(
q,
k,
v,
cu_seqlens_q,
kv_indptr,
kv_page_indices,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
logits_soft_cap=logits_soft_cap,
window_size_left=window_size[0],
window_size_right=window_size[1],
alibi_slopes=alibi_slopes,
return_lse=return_lse,
return_softmax=return_attn_probs and dropout_p > 0,
out=out,
)
out = out_padded[..., :head_size_v_og]
result = [out]
if return_lse:
result.append(softmax_lse)
if return_attn_probs:
result.append(S_dmask)
return result[0] if len(result) == 1 else tuple(result)
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
import torch
from torch import Tensor
from typing import Optional,List
from ..jit.core import (
compile_ops,
)
from .enum import ActivationType, Enum, QuantType
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w8a8(
input: torch.Tensor,
b_qweight : torch.Tensor,
output : torch.Tensor,
a_scale: torch.Tensor,
b_scale : torch.Tensor,
topk_weights : Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids : torch.Tensor,
num_tokens_post_pad: torch.Tensor,
top_k : int,
mode :int,
delta: int)-> torch.Tensor:
"""
---------------------------------------------------------------
# MoE 场景下 8bit 量化的 GEMM 计算(Marlin 优化版)
## 关键前置条件
必须配合对应的权重 Shuffle 函数使用,否则会导致计算结果完全错误:
- GEMM1 场景:使用 ops.marlin_weights 处理权重
- GEMM2 场景:使用 ops.marlin_weights_ours 处理权重
---------------------------------------------------------------
"""
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w4a8(
input: torch.Tensor,
b_qweight : torch.Tensor,
output : torch.Tensor,
a_scale: torch.Tensor,
b_scale : torch.Tensor,
topk_weights : Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids : torch.Tensor,
num_tokens_post_pad: torch.Tensor,
top_k : int,
mode :int,
delta: int)-> torch.Tensor:
"""
---------------------------------------------------------------
# MoE 场景下 8bit 量化的 GEMM 计算(Marlin 优化版)
## 关键前置条件
必须配合对应的权重 Shuffle 函数使用,否则会导致计算结果完全错误:
- GEMM1 场景:使用 ops.marlin_weights 处理权重
- GEMM2 场景:使用 ops.marlin_weights_ours 处理权重
---------------------------------------------------------------
"""
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w8a8_fp8(
input: torch.Tensor,
b_qweight : torch.Tensor,
output : torch.Tensor,
a_scale: torch.Tensor,
b_scale : torch.Tensor,
topk_weights : Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids : torch.Tensor,
num_tokens_post_pad: torch.Tensor,
top_k : int,
mode :int,
delta: int)-> torch.Tensor:
"""
---------------------------------------------------------------
# MoE 场景下 8bit 量化的 GEMM 计算(Marlin 优化版)
## 关键前置条件
必须配合对应的权重 Shuffle 函数使用,否则会导致计算结果完全错误:
- GEMM1 场景:使用 ops.marlin_weights 处理权重
- GEMM2 场景:使用 ops.marlin_weights_ours 处理权重
---------------------------------------------------------------
"""
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_gemm_marlin_w4a16(
input: torch.Tensor,
b_qweight : torch.Tensor,
output : torch.Tensor,
b_scale: torch.Tensor,
b_zeros : torch.Tensor,
topk_weights : Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids : torch.Tensor,
num_tokens_post_pad: torch.Tensor,
top_k : int,
mode :int,
delta: int)-> torch.Tensor:
"""
---------------------------------------------------------------
# MoE 场景下 4bit 量化的 GEMM 计算(Marlin 优化版)
## 关键前置条件
必须配合对应的权重 Shuffle 函数使用,否则会导致计算结果完全错误:
---------------------------------------------------------------
"""
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_w8a8_gemm_block_wise(
input: torch.Tensor,
a_scales: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
group_size_n: int,
group_size_k: int,
top_k: int,
BLOCK_SIZE_m: int,
BLOCK_SIZE_n: int,
BLOCK_SIZE_k: int,
kloops: int,
nloops: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_w8a8_gemm_block_wise_kernel2(
input: torch.Tensor,
a_scales: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
group_size_n: int,
group_size_k: int,
top_k: int,
BLOCK_SIZE_m: int,
BLOCK_SIZE_n: int,
BLOCK_SIZE_k: int,
kloops: int,
nloops: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_w8a8_gemm_block_wise_fp8(
input: torch.Tensor,
a_scales: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
group_size_n: int,
group_size_k: int,
top_k: int,
BLOCK_SIZE_m: int,
BLOCK_SIZE_n: int,
BLOCK_SIZE_k: int,
kloops: int,
nloops: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_w8a8_gemm_block_wise_kernel2_fp8(
input: torch.Tensor,
a_scales: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
group_size_n: int,
group_size_k: int,
top_k: int,
BLOCK_SIZE_m: int,
BLOCK_SIZE_n: int,
BLOCK_SIZE_k: int,
kloops: int,
nloops: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_w8a16_gemm_awq(
input: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
top_k: int,
BLOCK_SIZE_m: int,
BLOCK_SIZE_n: int,
BLOCK_SIZE_k: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_w8a16_gemm_block_wise(
input: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
group_size_n: int,
group_size_k: int,
top_k: int,
BLOCK_SIZE_m: int,
BLOCK_SIZE_n: int,
BLOCK_SIZE_k: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_wna16_gemm_base(
input: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
top_k: int,
BLOCK_SIZE_M: int,
BLOCK_SIZE_N: int,
BLOCK_SIZE_K: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_wna16_gemm(
input: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
top_k: int,
BLOCK_SIZE_m: int,
BLOCK_SIZE_n: int,
BLOCK_SIZE_k: int,
kloops: int,
nloops: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_wna16_gemm_2(
input: torch.Tensor,
output: torch.Tensor,
b_qweight: torch.Tensor,
b_scales: torch.Tensor,
b_qzeros: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
top_k: int,
BLOCK_SIZE_m: int,
BLOCK_SIZE_n: int,
BLOCK_SIZE_k: int,
kloops: int,
nloops: int,
bit: int
) -> torch.Tensor:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_topk_softmax(
topk_weights: torch.Tensor, # 移除 C++ 引用 &
topk_indices: torch.Tensor, # 移除 C++ 引用 &
token_expert_indices: torch.Tensor, # 移除 C++ 引用 &
gating_output: torch.Tensor # 移除 C++ 引用 &
) -> None: # 替代 -> None (C++ 中的 void)
pass
@compile_ops("module_moe_c_kernel")
def moe_c_silu_and_mul( out : torch.Tensor,
input : torch.Tensor) -> None:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_sum(
input: torch.Tensor, # 移除 C++ 引用 &
output: torch.Tensor, # 移除 C++ 引用 &
topk_ids: torch.Tensor
) -> None:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_moe_align_block_size(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor
) -> None:
pass
@compile_ops("module_moe_c_kernel")
def moe_c_sgl_moe_align_block_size(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor
) -> None:
pass
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from typing import Optional,List
from ..jit.core import (
compile_ops,
)
from .enum import ActivationType, Enum, QuantType
@compile_ops("module_moe_utils")
def topk_softmax(
topk_weights: Tensor,
topk_indices: Tensor,
token_expert_indices: Tensor,
gating_output: Tensor,
need_renorm: bool,
) -> None: ...
@compile_ops("module_moe_utils")
def moe_sum(input: Tensor, output: Tensor)->None: ...
@compile_ops("module_moe_sum")
def asm_moe_sum(input: Tensor, output: Tensor, sorted_ids: Tensor)->None: ...
@compile_ops("module_moe_utils")
def sgl_moe_align_block_size(topk_ids: Tensor, num_experts: int,
block_size: int, sorted_token_ids: Tensor,
experts_ids: Tensor,
num_tokens_post_pad: Tensor) -> None: ...
@compile_ops("module_moe_utils")
def moe_align_block_size(
topk_ids: Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: Tensor,
experts_ids: Tensor,
num_tokens_post_pad: Tensor,
) -> None: ...
@compile_ops("module_moe_asm")
def asm_fmoe_stage1(
out: Tensor,
input: Tensor,
gate: Tensor,
down: Tensor,
sorted_token_ids: Tensor,
sorted_weights: Tensor,
sorted_expert_ids: Tensor,
num_valid_ids: Tensor,
top_k: int,
scale_a: Optional[torch.Tensor] = None,
scale_b: Optional[torch.Tensor] = None,
zero_points: Optional[torch.Tensor] = None,
mode: Optional[int] = 0,
solidx: Optional[int] = 0,
block_size: Optional[int] = 16,
persist_groups: Optional[int] = 0,
) -> None: ...
@compile_ops("module_moe_asm")
def asm_fmoe_stage2(
out: Tensor,
input: Tensor,
gate: Tensor,
down: Tensor,
sorted_token_ids: Tensor,
sorted_weights: Tensor,
sorted_expert_ids: Tensor,
num_valid_ids: Tensor,
top_k: int,
scale_a: Optional[torch.Tensor] = None,
scale_b: Optional[torch.Tensor] = None,
zero_points: Optional[torch.Tensor] = None,
mode: Optional[int] = 0,
solidx: Optional[int] = 0,
block_size: Optional[int] = 16,
persist_groups: Optional[int] = 0,
)-> None: ...
@compile_ops("module_moe_asm")
def asm_fmoe_a8(
out: Tensor,
input: Tensor,
gate: Tensor,
down: Tensor,
sorted_token_ids: Tensor,
sorted_weights: Tensor,
sorted_expert_ids: Tensor,
num_valid_ids: Tensor,
top_k: int,
scale_a: Optional[torch.Tensor] = None,
scale_b: Optional[torch.Tensor] = None,
zero_points: Optional[torch.Tensor] = None,
mode: Optional[int] = 0,
solidx: Optional[int] = 0,
out_type:Optional[int] = 0,
persist_groups:Optional[int] = 0,
use_shuffle:Optional[int] = 0,
)-> None: ...
@compile_ops("module_moe_asm")
def asm_moe_get_solutions(
hidden_states: Tensor,
w1: Tensor,
w2: Tensor,
topk_weights: Tensor,
topk_ids: Tensor,
use_int8_w8a16: Optional[bool] = False,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a8: Optional[bool] = False,
use_int4_w4a8: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False,
per_channel_quant: Optional[bool] = False,
w1_zp: Optional[Tensor] = None,
w2_zp: Optional[Tensor] = None,
w1_scale: Optional[Tensor] = None,
w2_scale: Optional[Tensor] = None,
a1_scale: Optional[Tensor] = None,
a2_scale: Optional[Tensor] = None,
block_shape_n: Optional[int] = 0,
block_shape_k: Optional[int] = 0,
block_m: Optional[int] = 32,
expert_mask: Optional[Tensor] = None,
) -> list[str]: ...
# @compile_ops("module_moe_asm")
# def fmoe(
# out: Tensor,
# input: Tensor,
# gate: Tensor,
# down: Tensor,
# sorted_token_ids: Tensor,
# sorted_weights: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# topk: int,
# ): ...
# @compile_ops("module_moe_asm")
# def fmoe_int8_g1u0(
# out: Tensor,
# input: Tensor,
# gate: Tensor,
# down: Tensor,
# sorted_token_ids: Tensor,
# sorted_weights: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# topk: int,
# input_scale: Tensor,
# fc1_scale: Tensor,
# fc2_scale: Tensor,
# fc2_smooth_scale: Tensor,
# activation: Optional[Enum] = ActivationType.Silu,
# ): ...
# @compile_ops("module_moe_asm")
# def fmoe_g1u1(
# out: Tensor,
# input: Tensor,
# gate: Tensor,
# down: Tensor,
# sorted_token_ids: Tensor,
# sorted_weights: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# topk: int,
# input_scale: Tensor,
# fc1_scale: Tensor,
# fc2_scale: Tensor,
# fc2_smooth_scale: Optional[Tensor] = None,
# activation: Optional[Enum] = ActivationType.Silu,
# ): ...
# @compile_ops("module_moe_asm")
# def fmoe_g1u1_tkw1(
# out: Tensor,
# input: Tensor,
# gate: Tensor,
# down: Tensor,
# sorted_token_ids: Tensor,
# sorted_weights: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# topk: int,
# input_scale: Tensor,
# fc1_scale: Tensor,
# fc2_scale: Tensor,
# fc2_smooth_scale: Optional[Tensor] = None,
# activation: Optional[Enum] = ActivationType.Silu,
# ): ...
# @compile_ops("module_moe_asm")
# def fmoe_int8_g1u0_a16(
# out: Tensor,
# input: Tensor, # bf16
# gate: Tensor,
# down: Tensor,
# sorted_token_ids: Tensor,
# sorted_weights: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# topk: int,
# fc1_scale: Tensor,
# fc2_scale: Tensor,
# fc1_smooth_scale: Tensor,
# fc2_smooth_scale: Tensor,
# ): ...
# @compile_ops("module_moe_asm")
# def fmoe_g1u1_a16(
# out: Tensor,
# input: Tensor, # bf16
# gate: Tensor,
# down: Tensor,
# sorted_token_ids: Tensor,
# sorted_weights: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# topk: int,
# fc1_scale: Tensor,
# fc2_scale: Tensor,
# fc1_smooth_scale: Tensor,
# fc2_smooth_scale: Tensor,
# ): ...
# @compile_ops("module_moe_asm")
# def fmoe_fp8_blockscale_g1u1(
# out: Tensor,
# input: Tensor,
# gate: Tensor,
# down: Tensor,
# sorted_token_ids: Tensor,
# sorted_weights: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# topk: int,
# input_scale: Tensor,
# fc1_scale: Tensor,
# fc2_scale: Tensor,
# fc_scale_blkn: int = 128,
# fc_scale_blkk: int = 128,
# fc2_smooth_scale: Optional[Tensor] = None,
# activation: ActivationType = ActivationType.Silu,
# ): ...
# @compile_ops("module_moe_asm")
# def moe_stage1_g1u1(
# input: torch.Tensor,
# w1: torch.Tensor,
# w2: torch.Tensor,
# sorted_token_ids: torch.Tensor,
# sorted_expert_ids: torch.Tensor,
# num_valid_ids: torch.Tensor,
# out: torch.Tensor,
# inter_dim: int,
# kernelName: str,
# block_m: int,
# ksplit: int = 0,
# activation: ActivationType = ActivationType.Silu,
# quant_type: QuantType = QuantType.No,
# a1_scale: Optional[torch.Tensor] = None,
# w1_scale: Optional[torch.Tensor] = None,
# sorted_weights: Optional[torch.Tensor] = None,
# ) -> None: ...
@compile_ops("module_moe")
def ck_moe(
hidden_states: Tensor,
w1: Tensor,
w2: Tensor,
topk_weights: Tensor,
topk_ids: Tensor,
use_int8_w8a16: Optional[bool] = False,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a8_block: Optional[bool] = False,
use_int4_w4a8_block: Optional[bool] = False,
w1_zp: Optional[Tensor] = None,
w2_zp: Optional[Tensor] = None,
w1_scale: Optional[Tensor] = None,
w2_scale: Optional[Tensor] = None,
a1_scale: Optional[Tensor] = None,
a2_scale: Optional[Tensor] = None,
block_shape_n: Optional[int] = 0,
block_shape_k: Optional[int] = 0,
block_m: Optional[int] = 32,
solution_id: Optional[int] = 0,
expert_mask: Optional[Tensor] = None,
)-> torch.Tensor: ...
@compile_ops("module_moe")
def ck_shuffle_moe(
hidden_states: Tensor,
w1: Tensor,
w2: Tensor,
topk_weights: Tensor,
topk_ids: Tensor,
use_int8_w8a16: Optional[bool] = False,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a8_block: Optional[bool] = False,
use_int4_w4a8_block: Optional[bool] = False,
w1_zp: Optional[Tensor] = None,
w2_zp: Optional[Tensor] = None,
w1_scale: Optional[Tensor] = None,
w2_scale: Optional[Tensor] = None,
a1_scale: Optional[Tensor] = None,
a2_scale: Optional[Tensor] = None,
block_shape_n: Optional[int] = 0,
block_shape_k: Optional[int] = 0,
block_m: Optional[int] = 32,
solution_id: Optional[int] = 0,
expert_mask: Optional[Tensor] = None,
)-> torch.Tensor: ...
@compile_ops("module_moe")
def ck_moe_get_solutions(
hidden_states: Tensor,
w1: Tensor,
w2: Tensor,
topk_weights: Tensor,
topk_ids: Tensor,
use_int8_w8a16: Optional[bool] = False,
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a8_block: Optional[bool] = False,
use_int4_w4a8_block: Optional[bool] = False,
w1_zp: Optional[Tensor] = None,
w2_zp: Optional[Tensor] = None,
w1_scale: Optional[Tensor] = None,
w2_scale: Optional[Tensor] = None,
a1_scale: Optional[Tensor] = None,
a2_scale: Optional[Tensor] = None,
block_shape_n: Optional[int] = 0,
block_shape_k: Optional[int] = 0,
block_m: Optional[int] = 32,
expert_mask: Optional[Tensor] = None,
) -> list[int]: ...
@compile_ops("module_moe")
def ck_moe_stage_1(
hidden_states: Tensor,
w1: Tensor,
w2: Tensor,
sorted_token_ids: Tensor,
sorted_expert_ids: Tensor,
tokens_positions_per_expert: Tensor,
num_valid_ids: Tensor,
out: Tensor,
topk: int,
use_int8_w8a8_block: Optional[bool] = False,
use_fp8_w8a8_block: Optional[bool] = False,
w1_scale: Optional[Tensor] = None,
a1_scale: Optional[Tensor] = None,
block_shape_n: Optional[int] = 0,
block_shape_k: Optional[int] = 0,
block_m: Optional[int] = 32,
sorted_weights: Optional[Tensor] = None,
act_op: Optional[int] = 0,
)->None: ...
@compile_ops("module_moe")
def ck_moe_stage_2(
inter_states: Tensor, # the output of stage 1
w1: Tensor,
w2: Tensor,
sorted_token_ids: Tensor,
sorted_expert_ids: Tensor,
tokens_positions_per_expert: Tensor,
num_valid_ids: Tensor,
out: Tensor,
topk: int,
use_int8_w8a8_block: Optional[bool] = False,
use_fp8_w8a8_block: Optional[bool] = False,
w2_scale: Optional[Tensor] = None,
a2_scale: Optional[Tensor] = None,
block_shape_n: Optional[int] = 0,
block_shape_k: Optional[int] = 0,
block_m: Optional[int] = 32,
sorted_weights: Optional[Tensor] = None,
)->None: ...
@compile_ops("module_moe")
def ck_moe_per_token_quant(
input: Tensor,
out_quant: Tensor,
out_scale: Tensor,
)->None: ...
# @compile_ops("module_moe_ck2stages")
# def ck_moe_stage1(
# hidden_states: Tensor,
# w1: Tensor,
# w2: Tensor,
# sorted_token_ids: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# out: Tensor,
# topk: int,
# w1_scale: Optional[Tensor] = None,
# a1_scale: Optional[Tensor] = None,
# block_m: Optional[int] = 32,
# sorted_weights: Optional[Tensor] = None,
# act_op: Optional[int] = 0,
# ): ...
# @compile_ops("module_moe_ck2stages")
# def ck_moe_stage2(
# inter_states: Tensor,
# w1: Tensor,
# w2: Tensor,
# sorted_token_ids: Tensor,
# sorted_expert_ids: Tensor,
# num_valid_ids: Tensor,
# out: Tensor,
# topk: int,
# w2_scale: Optional[Tensor] = None,
# a2_scale: Optional[Tensor] = None,
# block_m: Optional[int] = 32,
# sorted_weights: Optional[Tensor] = None,
# ): ...
# SPDX-License-Identifier: MIT
import torch
from typing import Optional
from ..jit.core import compile_ops
MD_NAME = "module_moe_sorting"
@compile_ops("module_moe_sorting")
def moe_sorting_fwd(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
sorted_weights: torch.Tensor,
sorted_expert_ids: torch.Tensor,
tokens_positions_per_expert: torch.Tensor,
num_valid_ids: torch.Tensor,
moe_buf: torch.Tensor,
num_experts: int,
unit_size: int,
local_expert_mask: Optional[torch.Tensor] = None,
) ->None: ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from typing import Optional
from ..jit.core import compile_ops
MD_NAME = "module_norm"
def gen_layer_norm_fake_tensors(
input: Tensor,
# normalized_shape: List[int],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
x_bias: Optional[Tensor] = None,
) -> Tensor:
return torch.empty_like(
input,
dtype=input.dtype,
device=input.device,
)
@compile_ops(
"module_norm", fc_name="layernorm2d_fwd", gen_fake=gen_layer_norm_fake_tensors
)
def layer_norm(
input: Tensor,
# normalized_shape: List[int],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
epsilon: float = 1e-5,
x_bias: Optional[Tensor] = None,
) -> Tensor: ...
@compile_ops(
"module_norm", fc_name="layernorm2d_fwd", gen_fake=gen_layer_norm_fake_tensors
)
def layernorm2d_fwd(
input: Tensor,
# normalized_shape: List[int],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
epsilon: float = 1e-5,
x_bias: Optional[Tensor] = None,
) -> Tensor: ...
@compile_ops("module_norm")
def layernorm2d_fwd_with_add(
out: Tensor,
input: Tensor,
residual_in: Tensor,
residual_out: Tensor,
weight: Tensor,
bias: Tensor,
epsilon: float,
x_bias: Optional[Tensor] = None,
) -> None: ...
@compile_ops("module_norm")
def layernorm2d_fwd_with_smoothquant(
out: Tensor,
input: Tensor,
xscale: Tensor,
yscale: Tensor,
weight: Tensor,
bias: Tensor,
epsilon: float,
x_bias: Optional[Tensor] = None,
) -> None: ...
@compile_ops("module_norm")
def layernorm2d_fwd_with_add_smoothquant(
out: Tensor,
input: Tensor,
residual_in: Tensor,
residual_out: Tensor,
xscale: Tensor,
yscale: Tensor,
weight: Tensor,
bias: Tensor,
epsilon: float,
x_bias: Optional[Tensor] = None,
) -> None: ...
@compile_ops("module_norm")
def layernorm2d_fwd_with_dynamicquant(
out: Tensor,
input: Tensor,
yscale: Tensor,
weight: Tensor,
bias: Tensor,
epsilon: float,
x_bias: Optional[Tensor] = None,
) -> None: ...
@compile_ops("module_norm")
def layernorm2d_fwd_with_add_dynamicquant(
out: Tensor,
input: Tensor,
residual_in: Tensor,
residual_out: Tensor,
yscale: Tensor,
weight: Tensor,
bias: Tensor,
epsilon: float,
x_bias: Optional[Tensor] = None,
) -> None: ...
# @compile_ops("module_norm")
# def layernorm2d_with_add_asm(
# out: Tensor,
# input: Tensor,
# residual_in: Tensor,
# residual_out: Tensor,
# weight: Tensor,
# bias: Tensor,
# epsilon: float,
# x_bias: Optional[Tensor] = None,
# ): ...
# @compile_ops("module_norm")
# def layernorm2d_with_add_smoothquant_asm(
# out: Tensor,
# input: Tensor,
# residual_in: Tensor,
# residual_out: Tensor,
# xscale: Tensor,
# yscale: Tensor,
# weight: Tensor,
# bias: Tensor,
# epsilon: float,
# x_bias: Optional[Tensor] = None,
# ): ...
# SPDX-License-Identifier: MIT
from torch import Tensor
from ..jit.core import compile_ops
MD_NAME = "module_pos_encoding"
@compile_ops("module_pos_encoding")
def rotary_embedding_fwd(
positions: Tensor,
query: Tensor,
key: Tensor,
head_size: int,
cos_cache: Tensor,
sin_cache: Tensor,
is_neox: bool,
is_nope_first: bool,
) -> None: ...
@compile_ops("module_pos_encoding")
def batched_rotary_embedding(
positions: Tensor,
query: Tensor,
key: Tensor,
head_size: int,
cos_cache: Tensor,
sin_cache: Tensor,
is_neox: bool,
is_nope_first: bool,
rot_dim: int,
cos_sin_cache_offsets: Tensor,
) -> None: ...
\ No newline at end of file
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from typing import Optional
from ..jit.core import compile_ops
import torch.nn.functional as F
import functools
from .enum import QuantType, ActivationType
from . import triton
from ..utility import dtypes, fp4_utils
@compile_ops("module_smoothquant")
def smoothquant_fwd(
out: Tensor, input: Tensor, x_scale: Tensor, y_scale: Tensor
) -> None: ...
@compile_ops("module_smoothquant")
def moe_smoothquant_fwd(
out: Tensor, input: Tensor, x_scale: Tensor, topk_ids: Tensor, y_scale: Tensor
) -> None: ...
# following are pure torch implement
@functools.lru_cache()
def get_dtype_max(dtype):
try:
dtypeMax = torch.finfo(dtype).max
except:
dtypeMax = torch.iinfo(dtype).max
return dtypeMax
def pertoken_quant(
x,
scale=None,
x_scale=None, # smooth_scale
scale_dtype=dtypes.fp32,
quant_dtype=dtypes.i8,
dtypeMax=None,
):
x = x.to(dtypes.fp32)
if x_scale is None:
hidden_states = x
else:
# smooth quant
hidden_states = x * x_scale
if dtypeMax is None:
dtypeMax = get_dtype_max(quant_dtype)
per_token_scale = scale
if scale is None:
# [m, 1]
per_token_amax, _ = torch.max(
input=torch.abs(hidden_states), dim=-1, keepdim=True
)
per_token_scale = per_token_amax / dtypeMax
per_token_scale[per_token_scale == 0] = 1
# quant hidden_states
y = (hidden_states / per_token_scale).to(dtype=quant_dtype)
y_scale = per_token_scale.to(scale_dtype)
return y, y_scale
def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False):
assert quant_dtype == dtypes.fp4x2
block_size = 32
F8E8M0_EXP_BIAS = 127
F4E2M1_MAX = 6.0
MAX_POW2 = int(torch.log2(torch.tensor(F4E2M1_MAX, dtype=torch.float32)).item())
# dtypeMax = F4E2M1_MAX
dtypeMax = 2.0**MAX_POW2
shape_original = x.shape
x = x.view(-1, shape_original[-1])
m, n = x.shape
x = x.view(-1, block_size)
max_abs = torch.amax(torch.abs(x.float()), 1)
# max_abs = max_abs.view(torch.int32)
# max_abs = ((max_abs + 0x200000) & 0xFF800000).view(torch.float32)
# fp8e8m0fnu_from_fp32_value
scale_e8m0_biased = fp4_utils.f32_to_e8m0(max_abs / dtypeMax)
# Float8_e8m0fnu to float
scale_f32 = fp4_utils.e8m0_to_f32(scale_e8m0_biased)
y = x.float() / scale_f32.view(-1, 1)
y = fp4_utils.f32_to_mxfp4(y)
y = y.view(*shape_original[:-1], -1)
scale = scale_e8m0_biased.view(m, -1).view(torch.uint8)
if shuffle:
scale = fp4_utils.e8m0_shuffle(scale)
return y, scale.view(dtypes.fp8_e8m0)
def per_tensor_quant(
x, scale=None, scale_dtype=dtypes.fp32, quant_dtype=dtypes.i8, dtypeMax=None
):
x = x.to(dtypes.fp32)
if scale is None:
if dtypeMax is None:
dtypeMax = get_dtype_max(quant_dtype)
scale = torch.abs(x).max() / dtypeMax
y = x / scale
return y.to(quant_dtype), scale.view(1).to(scale_dtype)
def per_block_quant_wrapper(block_shape=(1, 128)):
def decorator(per_token_quant_func):
def wrapper(x, scale=None, quant_dtype=dtypes.i8):
blk_m, blk_n = block_shape
assert (
x.shape[-1] % blk_n == 0
), f"block size {blk_n} not match {x.shape[-1]}"
assert blk_m == 1, "only support 1xN block, TODO: support MxN"
m, n = x.shape
x = x.view(-1, blk_n)
y, scale = per_token_quant_func(x, scale=scale, quant_dtype=quant_dtype)
return y.view(m, n), scale.view(m, n // blk_n)
return wrapper
return decorator
@functools.lru_cache()
def get_torch_quant(qType):
tmp = {
QuantType.No: lambda *a, **k: (a[0], None),
QuantType.per_Tensor: per_tensor_quant,
QuantType.per_Token: pertoken_quant,
QuantType.per_1x32: per_1x32_f4_quant,
QuantType.per_1x128: per_block_quant_wrapper((1, 128))(pertoken_quant),
}
def raise_NotImplementedError(*a, **k):
raise NotImplementedError(f"unsupported quant type {qType=}")
return tmp.get(qType, raise_NotImplementedError)
@functools.lru_cache()
def get_hip_quant(qType):
tmp = {
QuantType.No.value: lambda *a, **k: (a[0], None),
QuantType.per_Tensor.value: per_tensor_quant_hip,
QuantType.per_Token.value: per_token_quant_hip,
QuantType.per_1x32.value: per_1x32_f4_quant_hip,
QuantType.per_1x128.value: functools.partial(
per_group_quant_hip, group_size=128
),
}
def raise_NotImplementedError(*a, **k):
raise NotImplementedError(f"unsupported quant type {qType=}")
return tmp.get(qType.value, raise_NotImplementedError)
@functools.lru_cache()
def get_triton_quant(qType):
tmp = {
QuantType.No: lambda *a, **k: (a[0], None),
QuantType.per_Tensor: per_tensor_quant_triton,
QuantType.per_Token: per_token_quant_triton,
QuantType.per_1x32: per_1x32_f4_quant_triton,
QuantType.per_1x128: per_block_quant_wrapper((1, 128))(per_token_quant_triton),
}
def raise_NotImplementedError(*a, **k):
raise NotImplementedError(f"unsupported quant type {qType=}")
return tmp.get(qType, raise_NotImplementedError)
def per_token_quant_hip(
x,
scale=None,
quant_dtype=dtypes.i8,
num_rows: Optional[torch.tensor] = None,
num_rows_factor=1,
):
shape = x.shape
device = x.device
if scale is None:
scale = torch.empty((*shape[:-1], 1), dtype=dtypes.fp32, device=device)
else:
raise ValueError("unsupported: static per token quant")
if 1:
y = torch.empty(shape, dtype=quant_dtype, device=device)
dynamic_per_token_scaled_quant(
y, x, scale, num_rows=num_rows, num_rows_factor=num_rows_factor
)
elif quant_dtype == dtypes.i8:
M, N = x.view(-1, shape[-1]).shape
y = torch.empty((M, N), dtype=dtypes.i8, device=device)
scale = torch.empty(M, dtype=dtypes.fp32, device=device)
smooth_scale = torch.ones(N, dtype=dtypes.fp32, device=device)
smoothquant_fwd(y, x, smooth_scale, scale)
y = y.view(shape)
else:
raise ValueError(f"unsupported: {quant_dtype=}")
# print("finished per token quant hip")
return y, scale
def per_group_quant_hip(
x,
scale=None,
quant_dtype=dtypes.i8,
group_size=128,
transpose_scale=False,
num_rows: Optional[torch.tensor] = None,
num_rows_factor=1,
):
shape = x.shape
device = x.device
if scale is None:
scale = torch.empty(
(*shape[:-1], shape[-1] // group_size), dtype=dtypes.fp32, device=device
)
else:
raise ValueError("unsupported: static per token quant")
assert group_size in [
32,
64,
128,
], f"unsupported group size {group_size=}, only support [32, 64, 128]"
y = torch.empty(shape, dtype=quant_dtype, device=device)
dynamic_per_token_scaled_quant(
y,
x.view(-1, group_size),
scale,
shuffle_scale=transpose_scale,
num_rows=num_rows,
num_rows_factor=num_rows_factor,
)
return y, scale
def per_1x32_f4_quant_hip(
x,
scale=None,
quant_dtype=dtypes.fp4x2,
shuffle=False,
num_rows: Optional[torch.tensor] = None,
num_rows_factor=1,
):
m, n = x.shape
assert quant_dtype == dtypes.fp4x2
assert n % 2 == 0
device = x.device
if scale is None:
if shuffle:
scale = (
torch.empty(
(
(m + 255) // 256 * 256,
(n // 32 + 7) // 8 * 8,
),
dtype=torch.uint8,
device=device,
)
# .fill_(0x7F)
.view(dtypes.fp8_e8m0)
)
else:
scale = (
torch.empty(
(m, n // 32),
dtype=torch.uint8,
device=device,
)
# .fill_(0x7F)
.view(dtypes.fp8_e8m0)
)
else:
raise ValueError("unsupported: static per token quant")
y = torch.empty(m, n // 2, dtype=quant_dtype, device=device)
dynamic_per_group_scaled_quant_fp4(
y,
x,
scale,
32,
shuffle_scale=shuffle,
num_rows=num_rows,
num_rows_factor=num_rows_factor,
)
return y, scale
def per_tensor_quant_hip(
x,
scale=None,
quant_dtype=dtypes.i8,
num_rows: Optional[torch.tensor] = None,
num_rows_factor=1,
):
assert num_rows is None, "num_rows is not supported for per_tensor_quant_hip"
y = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
if quant_dtype in [dtypes.fp8, dtypes.i8]:
if scale is None:
scale = torch.empty(1, dtype=dtypes.fp32, device=x.device)
dynamic_per_tensor_quant(y, x, scale)
else:
static_per_tensor_quant(y, x, scale)
else:
raise ValueError(f"unsupported: {quant_dtype=}")
return y, scale.view(1)
def per_token_quant_triton(x, scale=None, quant_dtype=dtypes.i8):
shape = x.shape
device = x.device
y = torch.empty(shape, dtype=quant_dtype, device=device)
if scale is None:
scale = torch.empty((*shape[:-1], 1), dtype=dtypes.fp32, device=device)
triton.quant.dynamic_per_token_quant_fp8_i8(y, x.view(-1, x.shape[-1]), scale)
else:
raise ValueError("unsupported: static per token quant")
return y, scale
def per_1x32_f4_quant_triton(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False):
assert quant_dtype == dtypes.fp4x2
# y, scale = triton.quant.dynamic_mxfp4_quant(x)
y, scale = fp4_utils.dynamic_mxfp4_quant(x, shuffle=shuffle)
return y.view(quant_dtype), scale
def per_tensor_quant_triton(x, scale=None, quant_dtype=dtypes.i8):
y = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
x = x.view(-1, x.shape[-1])
if scale is None:
scale = torch.zeros(1, dtype=dtypes.fp32, device=x.device)
triton.quant.dynamic_per_tensor_quant_fp8_i8(y, x, scale)
else:
triton.quant.static_per_tensor_quant_fp8_i8(y, x, scale)
return y, scale
@functools.lru_cache()
def get_torch_act(aType):
tmp = {
ActivationType.No: lambda *a, **k: a[0],
ActivationType.Silu: F.silu,
ActivationType.Gelu: F.gelu,
}
return tmp.get(aType, NotImplementedError)
@compile_ops("module_quant")
def static_per_tensor_quant(out: Tensor, input: Tensor, scale: Tensor) -> None: ...
@compile_ops("module_quant")
def dynamic_per_tensor_quant(out: Tensor, input: Tensor, scale: Tensor) -> None: ...
@compile_ops("module_quant")
def dynamic_per_token_scaled_quant(
out: torch.Tensor,
input: torch.Tensor,
scales: torch.Tensor,
scale_ub: Optional[torch.Tensor] = None,
shuffle_scale: bool = False,
num_rows: Optional[torch.Tensor] = None,
num_rows_factor: int = 1,
) -> None: ...
@compile_ops("module_quant")
def dynamic_per_group_scaled_quant_fp4(
out: Tensor,
input: Tensor,
scales: Tensor,
group_size: Optional[int] = 32,
shuffle_scale: bool = True,
num_rows: Optional[Tensor] = None,
num_rows_factor: int = 1,
) -> None:
"""
Only support group_size in [32, 64, 128]
"""
...
@compile_ops("module_quant")
def smooth_per_token_scaled_quant(
out: torch.Tensor,
input: torch.Tensor,
scales: torch.Tensor,
smooth_scale: torch.Tensor,
smooth_scale_map: Optional[torch.Tensor] = None,
shuffle_scale: bool = False,
num_rows: Optional[torch.Tensor] = None,
num_rows_factor: int = 1,
) -> None: ...
@compile_ops("module_quant")
def partial_transpose(
out: Tensor,
input: Tensor,
num_rows: Tensor,
) -> None: ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from ..jit.core import compile_ops
from typing import Optional
MD_NAME = "module_rmsnorm"
@compile_ops("module_rmsnorm")
def rms_norm_cu(
out: Tensor,
input: Tensor,
weight: Tensor,
epsilon: float,
) -> None:
"""
Cuda version of rmsnorm
"""
...
@compile_ops("module_rmsnorm")
def fused_add_rms_norm_cu(
input: Tensor, # input/out
residual_in: Tensor, # residual_in/out
weight: Tensor,
epsilon: float,
) -> None:
"""
Cuda version of rmsnorm fused add
"""
...
def gen_rms_norm_fake_tensor(
input: Tensor,
weight: Tensor,
epsilon: float,
) -> Tensor:
return torch.empty_like(input, dtype=input.dtype, device=input.device)
@compile_ops(
"module_rmsnorm", fc_name="rmsnorm2d_fwd", gen_fake=gen_rms_norm_fake_tensor
)
def rms_norm(
input: Tensor,
weight: Tensor,
epsilon: float,
) -> Tensor:
"""
CK version of rmsnorm
"""
...
@compile_ops("module_rmsnorm", gen_fake=gen_rms_norm_fake_tensor)
def rmsnorm2d_fwd(
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> Tensor: ...
@compile_ops("module_rmsnorm")
def rmsnorm2d_fwd_with_add(
out: Tensor,
input: Tensor,
residual_in: Tensor,
residual_out: Tensor,
weight: Tensor,
epsilon: float,
) -> None: ...
@compile_ops("module_rmsnorm")
def rmsnorm2d_fwd_with_smoothquant(
out: Tensor,
input: Tensor,
xscale: Tensor,
yscale: Tensor,
weight: Tensor,
epsilon: float,
) -> None: ...
@compile_ops("module_rmsnorm")
def rmsnorm2d_fwd_with_add_smoothquant(
out: Tensor,
input: Tensor,
residual_in: Tensor,
residual_out: Tensor,
xscale: Tensor,
yscale: Tensor,
weight: Tensor,
epsilon: float,
out_before_quant: Optional[Tensor] = None,
) -> None: ...
@compile_ops("module_rmsnorm")
def rmsnorm2d_fwd_with_dynamicquant(
out: Tensor,
input: Tensor,
yscale: Tensor,
weight: Tensor,
epsilon: float,
) -> None: ...
@compile_ops("module_rmsnorm")
def rmsnorm2d_fwd_with_add_dynamicquant(
out: Tensor,
input: Tensor,
residual_in: Tensor,
residual_out: Tensor,
yscale: Tensor,
weight: Tensor,
epsilon: float,
) -> None: ...
# SPDX-License-Identifier: MIT
from torch import Tensor, empty, empty_like, autograd
from typing import Tuple, Union
from ..jit.core import compile_ops
MD_NAME = "module_rope"
@compile_ops("module_rope_general_fwd")
def rope_fwd_impl(
output: Tensor,
input: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of traditional RoPE (Rotary Position Embedding).
Input and output should be in "sbhd" format and freqs should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, it should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_bwd")
def rope_bwd_impl(
input_grads: Tensor,
output_grads: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Backward propagation of traditional RoPE (Rotary Position Embedding).
Input and output should be in "sbhd" format and freqs should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, it should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_fwd")
def rope_2c_fwd_impl(
output_x: Tensor,
output_y: Tensor,
input_x: Tensor,
input_y: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of traditional RoPE (Rotary Position Embedding) on two channels.
Input and output should be in "sbhd" format and freqs should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, it should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_bwd")
def rope_2c_bwd_impl(
input_grads_x: Tensor,
input_grads_y: Tensor,
output_grads_x: Tensor,
output_grads_y: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Backward propagation of traditional RoPE (Rotary Position Embedding) on two channels.
Input and output should be in "sbhd" format and freqs should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, it should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_fwd")
def rope_cached_fwd_impl(
output: Tensor,
input: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin.
Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, they should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_bwd")
def rope_cached_bwd_impl(
input_grads: Tensor,
output_grads: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Backward propagation of RoPE (Rotary Position Embedding) with cached cos and sin.
Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, they should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_fwd")
def rope_cached_2c_fwd_impl(
output_x: Tensor,
output_y: Tensor,
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin on two channels.
Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, they should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_bwd")
def rope_cached_2c_bwd_impl(
input_grads_x: Tensor,
input_grads_y: Tensor,
output_grads_x: Tensor,
output_grads_y: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Backward propagation of RoPE (Rotary Position Embedding) with cached cos and sin on two channels.
Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, they should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_pos_fwd")
def rope_cached_positions_fwd_impl(
output: Tensor,
input: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin with positions and offsets
on one channel. Offsets here is optional. Both positions and offsets should be in [s, b].
Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, they should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_pos_fwd")
def rope_cached_positions_2c_fwd_impl(
output_x: Tensor,
output_y: Tensor,
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin with positions and offsets
on two channels. Offsets here is optional. Both positions and offsets should be in [s, b].
Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, they should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_pos_fwd")
def rope_cached_positions_offsets_fwd_impl(
output: Tensor,
input: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
offsets: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin with positions and offsets
on one channel. Offsets here is optional. Both positions and offsets should be in [s, b].
Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, they should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_pos_fwd")
def rope_cached_positions_offsets_2c_fwd_impl(
output_x: Tensor,
output_y: Tensor,
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
offsets: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of RoPE (Rotary Position Embedding) with cached cos and sin with positions and offsets
on two channels. Offsets here is optional. Both positions and offsets should be in [s, b].
Input and output should be in "sbhd" format, and cos and sin should be in shape of [s, 1, 1, d // 2]
if reuse_freqs_front_part is true. Otherwise, they should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_fwd")
def rope_thd_fwd_impl(
output: Tensor,
input: Tensor,
cu_seqlens: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of RoPE (Rotary Position Embedding) with input sizes: (t, h, d).
where t is cumulative sum of sequence lengths.
Freqs should be in shape of [s, 1, 1, d // 2] if reuse_freqs_front_part is true. Otherwise,
it should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_bwd")
def rope_thd_bwd_impl(
input_grads: Tensor,
output_grads: Tensor,
cu_seqlens: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Backward propagation of RoPE (Rotary Position Embedding) with input sizes: (t, h, d).
where t is cumulative sum of sequence lengths.
Freqs should be in shape of [s, 1, 1, d // 2] if reuse_freqs_front_part is true. Otherwise,
it should be in [s, 1, 1, d].
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_fwd")
def rope_2d_fwd_impl(
output: Tensor,
input: Tensor,
cos_h: Tensor,
sin_h: Tensor,
cos_w: Tensor,
sin_w: Tensor,
img_height: int,
img_width: int,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Forward propagation of RoPE (Rotary Position Embedding) with 2D image as input.
Input and output should be in (b, s, h, d) where s = H * W.
cos_h and sin_h are in (1, H', 1, h, d // 4) if reuse_freqs_front_part is true. Otherwise,
it should be in (1, H', 1, h, d // 2) where H' >= H.
cos_w and sin_w are in (1, 1, W', h, d // 2) if reuse_freqs_front_part is true. Otherwise,
it should be in (1, 1, W', h, d // 2) where W' >= W.
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
@compile_ops("module_rope_general_bwd")
def rope_2d_bwd_impl(
input_grads: Tensor,
output_grads: Tensor,
cos_h: Tensor,
sin_h: Tensor,
cos_w: Tensor,
sin_w: Tensor,
img_height: int,
img_width: int,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> None:
"""
Backward propagation of RoPE (Rotary Position Embedding) with 2D image as input.
output_grads and input_grads should be in (b, s, h, d) where s = H * W.
cos_h and sin_h are in (1, H', 1, h, d // 4) if reuse_freqs_front_part is true. Otherwise,
it should be in (1, H', 1, h, d // 2) where H' >= H.
cos_w and sin_w are in (1, 1, W', h, d // 2) if reuse_freqs_front_part is true. Otherwise,
it should be in (1, 1, W', h, d // 2) where W' >= W.
rotate_style: 0 - NEOX style which rotates the 2nd half of elements, 1 - GPT-J style which rotates odd part.
When rotate dim is smaller than d, front part is just copied if nope_first is true, or later part is copied
if nope_first is false. Rotate dim is freqs/cos/sin.shape[-1] * 2 if reuse_freqs_front_part else 1.
"""
...
def rope_fwd(
input: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h, d = input.shape
output = (
empty(
(b, s, h, d), dtype=input.dtype, device=input.device, requires_grad=False
).transpose(0, 1)
if transpose_output
else empty_like(input, requires_grad=False)
)
rope_fwd_impl(
output, input, freqs, rotate_style, reuse_freqs_front_part, nope_first
)
return output
def rope_fwd_inplace(
input: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_fwd_impl(input, input, freqs, rotate_style, reuse_freqs_front_part, nope_first)
def rope_bwd(
output_grads: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h, d = output_grads.shape
input_grads = (
empty(
(b, s, h, d),
dtype=output_grads.dtype,
device=output_grads.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(output_grads, requires_grad=False)
)
rope_bwd_impl(
input_grads,
output_grads,
freqs,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return input_grads
def rope_2c_fwd(
input_x: Tensor,
input_y: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h_x, d = input_x.shape
h_y = input_y.shape[2]
output_x = (
empty(
(b, s, h_x, d),
dtype=input_x.dtype,
device=input_x.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(input_x, requires_grad=False)
)
output_y = (
empty(
(b, s, h_y, d),
dtype=input_y.dtype,
device=input_y.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(input_y, requires_grad=False)
)
rope_2c_fwd_impl(
output_x,
output_y,
input_x,
input_y,
freqs,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return output_x, output_y
def rope_2c_fwd_inplace(
input_x: Tensor,
input_y: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_2c_fwd_impl(
input_x,
input_y,
input_x,
input_y,
freqs,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
def rope_2c_bwd(
output_grads_x: Tensor,
output_grads_y: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h_x, d = output_grads_x.shape
h_y = output_grads_y.shape[2]
input_grads_x = (
empty(
(b, s, h_x, d),
dtype=output_grads_x.dtype,
device=output_grads_x.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(output_grads_x, requires_grad=False)
)
input_grads_y = (
empty(
(b, s, h_y, d),
dtype=output_grads_y.dtype,
device=output_grads_y.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(output_grads_y, requires_grad=False)
)
rope_2c_bwd_impl(
input_grads_x,
input_grads_y,
output_grads_x,
output_grads_y,
freqs,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return input_grads_x, input_grads_y
def rope_cached_fwd(
input: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h, d = input.shape
output = (
empty(
(b, s, h, d), dtype=input.dtype, device=input.device, requires_grad=False
).transpose(0, 1)
if transpose_output
else empty_like(input, requires_grad=False)
)
rope_cached_fwd_impl(
output, input, cos, sin, rotate_style, reuse_freqs_front_part, nope_first
)
return output
def rope_cached_fwd_inplace(
input: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_cached_fwd_impl(
input, input, cos, sin, rotate_style, reuse_freqs_front_part, nope_first
)
def rope_cached_bwd(
output_grads: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h, d = output_grads.shape
input_grads = (
empty(
(b, s, h, d),
dtype=output_grads.dtype,
device=output_grads.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(output_grads, requires_grad=False)
)
rope_cached_bwd_impl(
input_grads,
output_grads,
cos,
sin,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return input_grads
def rope_cached_2c_fwd(
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h_x, d = input_x.shape
h_y = input_y.shape[2]
output_x = (
empty(
(b, s, h_x, d),
dtype=input_x.dtype,
device=input_x.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(input_x, requires_grad=False)
)
output_y = (
empty(
(b, s, h_y, d),
dtype=input_y.dtype,
device=input_y.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(input_y, requires_grad=False)
)
rope_cached_2c_fwd_impl(
output_x,
output_y,
input_x,
input_y,
cos,
sin,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return output_x, output_y
def rope_cached_2c_fwd_inplace(
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_cached_2c_fwd_impl(
input_x,
input_y,
input_x,
input_y,
cos,
sin,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
def rope_cached_2c_bwd(
output_grads_x: Tensor,
output_grads_y: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h_x, d = output_grads_x.shape
h_y = output_grads_y.shape[2]
input_grads_x = (
empty(
(b, s, h_x, d),
dtype=output_grads_x.dtype,
device=output_grads_x.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(output_grads_x, requires_grad=False)
)
input_grads_y = (
empty(
(b, s, h_y, d),
dtype=output_grads_y.dtype,
device=output_grads_y.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(output_grads_y, requires_grad=False)
)
rope_cached_2c_bwd_impl(
input_grads_x,
input_grads_y,
output_grads_x,
output_grads_y,
cos,
sin,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return input_grads_x, input_grads_y
def rope_cached_positions_fwd(
input: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h, d = input.shape
output = (
empty(
(b, s, h, d), dtype=input.dtype, device=input.device, requires_grad=False
).transpose(0, 1)
if transpose_output
else empty_like(input, requires_grad=False)
)
rope_cached_positions_fwd_impl(
output,
input,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return output
def rope_cached_positions_2c_fwd(
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h_x, d = input_x.shape
h_y = input_y.shape[2]
output_x = (
empty(
(b, s, h_x, d),
dtype=input_x.dtype,
device=input_x.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(input_x, requires_grad=False)
)
output_y = (
empty(
(b, s, h_y, d),
dtype=input_y.dtype,
device=input_y.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(input_y, requires_grad=False)
)
rope_cached_positions_2c_fwd_impl(
output_x,
output_y,
input_x,
input_y,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return output_x, output_y
def rope_cached_positions_fwd_inplace(
input: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_cached_positions_fwd_impl(
input,
input,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
def rope_cached_positions_2c_fwd_inplace(
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_cached_positions_2c_fwd_impl(
input_x,
input_y,
input_x,
input_y,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
def rope_cached_positions_offsets_fwd(
input: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
offsets: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h, d = input.shape
output = (
empty(
(b, s, h, d), dtype=input.dtype, device=input.device, requires_grad=False
).transpose(0, 1)
if transpose_output
else empty_like(input, requires_grad=False)
)
rope_cached_positions_offsets_fwd_impl(
output,
input,
cos,
sin,
positions,
offsets,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return output
def rope_cached_positions_offsets_2c_fwd(
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
offsets: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
s, b, h_x, d = input_x.shape
h_y = input_y.shape[2]
output_x = (
empty(
(b, s, h_x, d),
dtype=input_x.dtype,
device=input_x.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(input_x, requires_grad=False)
)
output_y = (
empty(
(b, s, h_y, d),
dtype=input_y.dtype,
device=input_y.device,
requires_grad=False,
).transpose(0, 1)
if transpose_output
else empty_like(input_y, requires_grad=False)
)
rope_cached_positions_offsets_2c_fwd_impl(
output_x,
output_y,
input_x,
input_y,
cos,
sin,
positions,
offsets,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return output_x, output_y
def rope_cached_positions_offsets_fwd_inplace(
input: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
offsets: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_cached_positions_offsets_fwd_impl(
input,
input,
cos,
sin,
positions,
offsets,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
def rope_cached_positions_offsets_2c_fwd_inplace(
input_x: Tensor,
input_y: Tensor,
cos: Tensor,
sin: Tensor,
positions: Tensor,
offsets: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_cached_positions_offsets_2c_fwd_impl(
input_x,
input_y,
input_x,
input_y,
cos,
sin,
positions,
offsets,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
def rope_thd_fwd(
input: Tensor,
cu_seqlens: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
output = empty_like(input, requires_grad=False)
rope_thd_fwd_impl(
output,
input,
cu_seqlens,
freqs,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return output
def rope_thd_fwd_inplace(
input: Tensor,
cu_seqlens: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_thd_fwd_impl(
input,
input,
cu_seqlens,
freqs,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
def rope_thd_bwd(
output_grads: Tensor,
cu_seqlens: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
input_grads = empty_like(output_grads, requires_grad=False)
rope_thd_bwd_impl(
input_grads,
output_grads,
cu_seqlens,
freqs,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return input_grads
def rope_2d_fwd(
input: Tensor,
cos_h: Tensor,
sin_h: Tensor,
cos_w: Tensor,
sin_w: Tensor,
img_height: int,
img_width: int,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
output = empty_like(input, requires_grad=False)
rope_2d_fwd_impl(
output,
input,
cos_h,
sin_h,
cos_w,
sin_w,
img_height,
img_width,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return output
def rope_2d_fwd_inplace(
input: Tensor,
cos_h: Tensor,
sin_h: Tensor,
cos_w: Tensor,
sin_w: Tensor,
img_height: int,
img_width: int,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
rope_2d_fwd_impl(
input,
input,
cos_h,
sin_h,
cos_w,
sin_w,
img_height,
img_width,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
def rope_2d_bwd(
output_grads: Tensor,
cos_h: Tensor,
sin_h: Tensor,
cos_w: Tensor,
sin_w: Tensor,
img_height: int,
img_width: int,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
input_grads = empty_like(output_grads, requires_grad=False)
rope_2d_bwd_impl(
input_grads,
output_grads,
cos_h,
sin_h,
cos_w,
sin_w,
img_height,
img_width,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
return input_grads
class RoPE(autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
ctx.rotate_style = rotate_style
ctx.reuse_freqs_front_part = reuse_freqs_front_part
ctx.nope_first = nope_first
ctx.transpose_output = transpose_output
ctx.save_for_backward(freqs)
return rope_fwd(
x, freqs, rotate_style, reuse_freqs_front_part, nope_first, transpose_output
)
@staticmethod
def backward(ctx, output_grads: Tensor) -> Tuple[Union[Tensor, None], ...]:
(freqs,) = ctx.saved_tensors
return (
rope_bwd(
output_grads,
freqs,
ctx.rotate_style,
ctx.reuse_freqs_front_part,
ctx.nope_first,
ctx.transpose_output,
),
None,
None,
)
class RoPECached(autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
cos: Tensor,
sin: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
transpose_output: bool = False,
) -> Tensor:
ctx.rotate_style = rotate_style
ctx.reuse_freqs_front_part = reuse_freqs_front_part
ctx.nope_first = nope_first
ctx.transpose_output = transpose_output
ctx.save_for_backward(cos, sin)
return rope_cached_fwd(
x,
cos,
sin,
rotate_style,
reuse_freqs_front_part,
nope_first,
transpose_output,
)
@staticmethod
def backward(ctx, output_grads) -> Tuple[Union[Tensor, None], ...]:
cos, sin = ctx.saved_tensors
return (
rope_cached_bwd(
output_grads,
cos,
sin,
ctx.rotate_style,
ctx.reuse_freqs_front_part,
ctx.nope_first,
ctx.transpose_output,
),
None,
None,
)
class RoPETHD(autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
cu_seqlens: Tensor,
freqs: Tensor,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
):
ctx.rotate_style = rotate_style
ctx.reuse_freqs_front_part = reuse_freqs_front_part
ctx.nope_first = nope_first
ctx.save_for_backward(cu_seqlens, freqs)
return rope_thd_fwd(
x, cu_seqlens, freqs, rotate_style, reuse_freqs_front_part, nope_first
)
@staticmethod
def backward(ctx, output_grads) -> Tuple[Union[Tensor, None], ...]:
cu_seqlens, freqs = ctx.saved_tensors
return (
rope_thd_bwd(
output_grads,
cu_seqlens,
freqs,
ctx.rotate_style,
ctx.reuse_freqs_front_part,
ctx.nope_first,
),
None,
None,
)
class RoPE2D(autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
cos_height: Tensor,
sin_height: Tensor,
cos_width: Tensor,
sin_width: Tensor,
img_height: int,
img_width: int,
rotate_style: int,
reuse_freqs_front_part: bool,
nope_first: bool,
) -> Tensor:
ctx.img_height = img_height
ctx.img_width = img_width
ctx.rotate_style = rotate_style
ctx.reuse_freqs_front_part = reuse_freqs_front_part
ctx.nope_first = nope_first
ctx.save_for_backward(cos_height, sin_height, cos_width, sin_width)
return rope_2d_fwd(
x,
cos_height,
sin_height,
cos_width,
sin_width,
img_height,
img_width,
rotate_style,
reuse_freqs_front_part,
nope_first,
)
@staticmethod
def backward(ctx, output_grads) -> Tuple[Union[Tensor, None], ...]:
cos_height, sin_height, cos_width, sin_width = ctx.saved_tensors
return (
rope_2d_bwd(
output_grads,
cos_height,
sin_height,
cos_width,
sin_width,
ctx.img_height,
ctx.img_height,
ctx.rotate_style,
ctx.reuse_freqs_front_part,
ctx.nope_first,
),
None,
None,
)
# SPDX-License-Identifier: MIT
import torch
import numpy as np
# Moe_c Shuffle Function
#=================================================================================================================
def moe_layout_shuffle_gemm1(weight):
return _w8a8_marlin_weight_1(weight)
def moe_layout_shuffle_gemm2(weight):
return _w8a8_marlin_weight_2(weight)
def w4a8_moe_layout_shuffle_gemm1(weight):
return _w4a8_gemm1_weight_shuffle(weight)
def w4a8_moe_layout_shuffle_gemm2(weight):
return _w4a8_gemm2_weight_shuffle(weight)
#w4a16
def w4a16_marlin_weight_1(weight_input # [size_n, size_k// 2 ]
):
w1_qweight = weight_input
e,n,k=w1_qweight.shape
k = k * 2
w1_qweight_uint32 = w1_qweight.view(-1).view(torch.uint32)
new_shape = (e, n // 16, 16, k // 32, 4) # uint32张量的形状
w1_qweight_uint32_reshaped = w1_qweight_uint32.view(new_shape)
w1_qweight_uint32_transposed = w1_qweight_uint32_reshaped.transpose(2, 3).contiguous()
new_shape = (e, n // 16, k // 128, 4, 16, 4)
w1_new_trans = w1_qweight_uint32_transposed.view(new_shape)
w1_qweight_shuffle = w1_new_trans.transpose(1, 2).contiguous()
return w1_qweight_shuffle
def w4a16_marlin_weight_2(weight_input # [size_n, size_k// 2 ]
):
w2_qweight = weight_input
e,k,n=w2_qweight.shape
n = n * 2
w2_qweight_uint32 = w2_qweight.view(-1).view(torch.uint32)
new_shape = (e, k // 16, 16, n // 32, 4) # uint32张量的形状
w2_qweight_uint32_reshaped = w2_qweight_uint32.view(new_shape)
w2_qweight_uint32_transposed = w2_qweight_uint32_reshaped.transpose(2, 3).contiguous()
new_shape = (e, k // 16, n // 128, 4, 16, 4)
w2_new_trans = w2_qweight_uint32_transposed.view(new_shape)
w2_qweight_shuffle = w2_new_trans.transpose(1, 2).contiguous()
return w2_qweight_shuffle
#w8a8
def _w8a8_marlin_weight_1(weight_input # [size_n, size_k// 2 ]
):
weight = weight_input
weight = weight.permute(0,2,1)
marlin_q_w = _marlin_weights(weight, k_tile=64, n_tile=16, pack_factor=8)
return marlin_q_w
def _w8a8_marlin_weight_2(weight_input # [size_n, size_k// 2 ]
):
weight = weight_input
weight = weight.permute(0,2,1)
marlin_q_w = _marlin_weights_2(weight, k_tile=64, n_tile=16, pack_factor=8)
return marlin_q_w
def _marlin_weights(
q_w,
k_tile=64,
n_tile=16,
pack_factor=8):
# 7168, 256
e,size_k, size_n = q_w.shape
q_w = q_w.reshape(e,size_k // k_tile, k_tile, size_n )
q_w = q_w.permute(0,1,3,2).contiguous()
q_w = q_w.reshape(e,size_k // k_tile, size_n * k_tile)
return q_w
def _marlin_weights_2(
q_w,
k_tile=64,
n_tile=16,
pack_factor=8):
# 128 7168
e, size_k, size_n = q_w.shape
q_w = q_w.reshape(e,size_k // k_tile, k_tile, size_n //n_tile , n_tile )
q_w = q_w.permute((0,1, 3, 4, 2)).contiguous()
q_w = q_w.reshape(e, size_k // k_tile , size_n //n_tile , n_tile // 16 , 16, k_tile // 16 , 16 )
q_w = q_w.permute(0,1,2,3,5,4,6).contiguous()
return q_w
# w4a8
def _w4a8_gemm1_weight_shuffle(w4a8_w):
full_w4a8_w = w4a8_w
full_w4a8_w = full_w4a8_w.T
k_tile=32
n_tile=256
size_k, size_n = full_w4a8_w.shape
full_w4a8_w = full_w4a8_w.reshape(size_k // k_tile, k_tile, size_n //n_tile , n_tile )
full_w4a8_w = full_w4a8_w.permute((0, 2, 3, 1)).contiguous()
full_w4a8_w = full_w4a8_w.reshape(size_k // k_tile , size_n //n_tile , n_tile // 32 , 32, k_tile // 8 , 8 )
full_w4a8_w = full_w4a8_w.permute(0,1,2,4,3,5).contiguous()
return full_w4a8_w
def _w4a8_gemm2_weight_shuffle(w4a8_w):
full_w4a8_w = w4a8_w
full_w4a8_w = full_w4a8_w.T
k_tile=32
n_tile=256
size_k, size_n = full_w4a8_w.shape
full_w4a8_w = full_w4a8_w.reshape(size_k // k_tile, k_tile, size_n //n_tile , n_tile )
full_w4a8_w = full_w4a8_w.permute((0, 2, 3, 1)).contiguous()
full_w4a8_w = full_w4a8_w.reshape(size_k // k_tile , size_n //n_tile , n_tile // 32 , 32, k_tile // 8 , 8 )
full_w4a8_w = full_w4a8_w.permute(0,1,2,4,3,5).contiguous()
return full_w4a8_w
#=======================================================Moe_c Shuffle Function================================================================
def asm_shuffle_weight_b8(x: torch.Tensor, stage: torch.int32 = 1) -> torch.Tensor:
# Hardcode BLOCK_K and BLOCK_N
assert x.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.int8, torch.float8_e4m3fn
]
if x.dtype == torch.int8 or x.dtype == torch.float8_e4m3fn:
N = 16
K = 16
IK = 64
IN = 64
BK = 256
BN = 128
if stage == 1:
if x.shape[-2] % 128 != 0 and x.shape[-2] % 64 == 0:
BN = 64
if stage == 2:
if x.shape[-1] % 128 == 0:
BK = 128
elif x.shape[-1] % 128 == 64:
BN = 64
BK = 64
elif x.shape[-1] % 128 == 96:
BN = 64
BK = 64
assert x.shape[-2] % BN == 0, f"{x.shape[-2]} % {BN} == {x.shape[-2] % BN }"
x_ = x
multiple = x.shape[-1] // BK * BK
part1 = x[:, :, :multiple]
### part1 shuffle
# 0, 1, 2, 3, 4, 5, 6, 7, 8
part1 = part1.view(-1, part1.shape[-2] // BN, BN // IN, IN // N, N, part1.shape[-1] // BK, BK // IK, IK // K, K)
part1 = part1.permute(0, 1, 5, 2, 6, 3, 7, 4, 8).contiguous()
part1 = part1.flatten(start_dim=1)
### part2 shuffle
part2 = x[:, :, multiple:]
IK = 32
BK = 32
# 0, 1, 2, 3, 4, 5, 6, 7, 8
part2 = part2.view(-1, part2.shape[-2] // BN, BN // IN, IN // N, N, part2.shape[-1] // BK, BK // IK, IK // K, K)
part2 = part2.permute(0, 1, 5, 2, 6, 3, 7, 4, 8).contiguous()
part2 = part2.flatten(start_dim=1)
### combine
x_ = torch.cat((part1, part2), dim=1)
x_ = x_.view(*x.shape)
return x_
elif x.dtype == torch.float16 or x.dtype == torch.bfloat16:
N = 16
K = 8
IK = 32
IN = 64
BK = 128
BN = 64
if stage == 2:
BK = 32
else:
assert False, f"not support {x.dtype}"
assert x.shape[-2] % BN == 0, f"{x.shape[-2]} % {BN} == {x.shape[-2] % BN }"
assert x.shape[-1] % BK == 0, f"{x.shape[-1]} % {BK} == {x.shape[-1] % BK }"
x_ = x
# 0, 1, 2, 3, 4, 5, 6, 7, 8
x_ = x_.view(-1, x.shape[-2] // BN, BN // IN, IN // N, N, x.shape[-1] // BK, BK // IK, IK // K, K)
x_ = x_.permute(0, 1, 5, 2, 6, 3, 7, 4, 8)
x_ = x_.contiguous()
x_ = x_.view(*x.shape)
return x_
def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Tensor:
# Hardcode BLOCK_K and BLOCK_N
IN, IK = layout
BK = IK * 2
K = 16 // x.element_size() if not use_int4 else 32
BN = IN
assert x.shape[-2] % BN == 0, f"{x.shape[-2]} % {BN} == {x.shape[-2] % BN }"
assert x.shape[-1] % BK == 0, f"{x.shape[-1]} % {BK} == {x.shape[-1] % BK }"
x_ = x
x_ = x_.view(-1, x.shape[-2] // BN, BN, x.shape[-1] // BK, BK // K, K)
x_ = x_.permute(0, 1, 3, 4, 2, 5)
x_ = x_.contiguous()
x_ = x_.view(*x.shape)
return x_
# TN Layout in -> CK Tiling Layout out
# layout(NWaves, NRepeat, NLane, NInterleave, NVec, KWaves, KRepeat, KLane, KVec)
def ck_shuffle_weight(x:torch.Tensor, layout=(4, 1, 16, 2, 1, 1, 4, 4, 8)) -> torch.Tensor:
NWaves, NRepeat, NLane, NInterleave, NVec, KWaves, KRepeat, KLane, KVec = layout
Block_N = NWaves * NRepeat * NLane * NInterleave * NVec
Block_K = KWaves * KRepeat * KLane * KVec
assert x.shape[-2] % Block_N == 0, f"{x.shape[-2]} % {Block_N} == {x.shape[-2] % Block_N }"
assert x.shape[-1] % Block_K == 0, f"{x.shape[-1]} % {Block_K} == {x.shape[-1] % Block_K }"
x_ = x
# (0, 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11)
x_ = x_.view(-1, x.shape[-2] // Block_N, NWaves, NRepeat, NLane, NInterleave, NVec, x.shape[-1] // Block_K, KWaves, KRepeat, KLane, KVec)
# x_ = x_.permute(0, 1, 7, 3, 5, 9, 8, 2, 10, 4, 6, 11)
x_ = x_.permute(0, 1, 7, 2, 8, 3, 5, 9, 10, 4, 6, 11)
x_ = x_.contiguous()
x_ = x_.view(-1, x.shape[-2] // Block_N, x.shape[-1] // Block_K, Block_N * Block_K)
return x_
# layout(NWaves, NRepeat, NLane, NInterleave, NVec, KWaves, KRepeat, KLane, KVec)
def ck_shuffle_weight_down(x:torch.Tensor, layout=(4, 2, 16, 1, 1, 1, 4, 4, 8)) -> torch.Tensor:
NWaves, NRepeat, NLane, NInterleave, NVec, KWaves, KRepeat, KLane, KVec = layout
Block_N = NWaves * NRepeat * NLane * NInterleave * NVec
Block_K = KWaves * KRepeat * KLane * KVec
assert x.shape[-2] % Block_N == 0, f"{x.shape[-2]} % {Block_N} == {x.shape[-2] % Block_N }"
assert x.shape[-1] % Block_K == 0, f"{x.shape[-1]} % {Block_K} == {x.shape[-1] % Block_K }"
x_ = x
# (0, 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11)
x_ = x_.view(-1, x.shape[-2] // Block_N, NWaves, NRepeat, NLane, NInterleave, NVec, x.shape[-1] // Block_K, KWaves, KRepeat, KLane, KVec)
x_ = x_.permute(0, 7, 1, 2, 8, 3, 5, 9, 10, 4, 6, 11) #down weight loop in N dim
x_ = x_.contiguous()
x_ = x_.view(-1, x.shape[-1] // Block_K, x.shape[-2] // Block_N, Block_N * Block_K)
return x_
def reverse_awq_order(tensor: torch.Tensor) -> torch.Tensor:
"""Reverse the AWQ order of the given tensor.
Args:
tensor: Input tensor to reorder
Returns:
Reordered tensor with bits masked to 4 bits
"""
bits = 4
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
reverse_order_tensor = torch.arange(
tensor.shape[-1],
dtype=torch.int32,
device=tensor.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
tensor = tensor[:, reverse_order_tensor] & 0xF
return tensor
def awq_reorder_and_repack(
qweight: torch.Tensor,
qzeros: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Reorder and pack weights and zeros using AWQ order.
This function unpacks the 4-bit quantized weights and zeros from int32,
applies reverse_awq_order to reorder them, and then packs them.
For weight, repack to [N, K//2]
For zeros, repack to [K//G, N//2]
Args:
qweight: Quantized weight tensor of shape [K, N // 8] with dtype int32
qzeros: Quantized zero points tensor of shape [K // G, N // 8] with dtype int32
Returns:
Tuple of (reordered_qweight, reordered_qzeros) both with dtype int8
"""
bits = 4
shifts = torch.arange(0, 32, bits, device=qweight.device)
K = qweight.shape[0]
N = qweight.shape[1] * 8
G = K // qzeros.shape[0]
# Unpack weights: [K, N//8] -> [K, N//8, 8] -> [K, N]
iweights = torch.bitwise_right_shift(
qweight[:, :, None],
shifts[None, None, :],
).to(torch.int8)
iweights = iweights.view(K, -1)
# Unpack zeros: [K//G, N//8] -> [K//G, N//8, 8] -> [K//G, N]
zeros = torch.bitwise_right_shift(
qzeros[:, :, None],
shifts[None, None, :],
).to(torch.int8)
zeros = zeros.view(K//G, -1)
# Apply reverse AWQ order to both tensors
iweights = reverse_awq_order(iweights)
zeros = reverse_awq_order(zeros)
# Mask to 4 bits
iweights = torch.bitwise_and(iweights, (2**bits) - 1)
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
# Repack weight to int32 and pack along the K direction
# [K, N] -> [N, K]
iweights = iweights.transpose(1, 0).contiguous()
# Reshape to [N, K//2, 2] for weights
iweights_packed = iweights.view(N, -1, 2)
# Repack zeros to int8 and pack along the N direction
# Reshape to [K//G, N//2, 2] for zeros
zeros_packed = zeros.view(K//G, -1, 2)
# Pack 2 int4 values into int8 using bit shifts
# Direct packing: pack in the order they appear after reordering
packed_weights = torch.zeros([N, K//2], dtype=torch.int8, device=qweight.device)
packed_zeros = torch.zeros([K//G, N//2], dtype=torch.int8, device=zeros.device)
for i in range(2):
packed_weights |= (iweights_packed[:, :, i].to(torch.int8) << (i * bits))
packed_zeros |= (zeros_packed[:, :, i].to(torch.int8) << (i * bits))
return packed_weights, packed_zeros
# SPDX-License-Identifier: MIT
from .sparse_mla_fwd import tilelang_sparse_fwd, ref_sparse_mla_fwd_interface
__all__ = ["tilelang_sparse_fwd", "ref_sparse_mla_fwd_interface"]
# Auto-generated by tune_fp8_index.py. Do not edit manually.
from typing import Tuple, Dict
import bisect
M_REPR_TABLE = [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 640, 768, 896, 1024, 1280, 1536, 1792, 2048, 2560, 3072, 3584, 4096]
N_REPR_TABLE = [64, 128, 256, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576, 32768, 40960, 49152, 57344, 65536, 73728, 81920, 90112, 98304, 106496, 114688, 122880, 131072]
CONFIG_MAP: Dict[Tuple[int, int], Tuple[int, int, int]] = {
(1, 64): (1, 128, 64),
(1, 128): (1, 128, 64),
(1, 256): (1, 128, 64),
(1, 512): (1, 128, 64),
(1, 768): (1, 128, 64),
(1, 1024): (1, 128, 64),
(1, 1536): (1, 128, 64),
(1, 2048): (1, 128, 64),
(1, 3072): (1, 128, 64),
(1, 4096): (1, 128, 64),
(1, 6144): (1, 128, 64),
(1, 8192): (1, 128, 64),
(1, 12288): (1, 128, 64),
(1, 16384): (1, 128, 64),
(1, 24576): (1, 128, 64),
(1, 32768): (1, 128, 64),
(1, 40960): (1, 128, 64),
(1, 49152): (1, 128, 64),
(1, 57344): (1, 128, 64),
(1, 65536): (1, 128, 64),
(1, 73728): (1, 128, 64),
(1, 81920): (1, 128, 64),
(1, 90112): (1, 128, 64),
(1, 98304): (1, 128, 64),
(1, 106496): (1, 128, 64),
(1, 114688): (1, 128, 64),
(1, 122880): (1, 128, 64),
(1, 131072): (1, 128, 64),
(2, 64): (2, 64, 64),
(2, 128): (2, 64, 64),
(2, 256): (2, 64, 64),
(2, 512): (2, 64, 64),
(2, 768): (2, 64, 64),
(2, 1024): (2, 64, 64),
(2, 1536): (2, 64, 64),
(2, 2048): (2, 64, 64),
(2, 3072): (2, 64, 64),
(2, 4096): (2, 64, 64),
(2, 6144): (2, 128, 64),
(2, 8192): (2, 128, 64),
(2, 12288): (2, 256, 64),
(2, 16384): (2, 256, 64),
(2, 24576): (2, 256, 64),
(2, 32768): (2, 256, 64),
(2, 40960): (2, 256, 64),
(2, 49152): (2, 256, 64),
(2, 57344): (2, 256, 64),
(2, 65536): (2, 256, 64),
(2, 73728): (2, 256, 64),
(2, 81920): (2, 512, 128),
(2, 90112): (2, 512, 128),
(2, 98304): (2, 512, 128),
(2, 106496): (2, 512, 128),
(2, 114688): (2, 512, 128),
(2, 122880): (2, 512, 128),
(2, 131072): (2, 512, 128),
(4, 64): (2, 64, 64),
(4, 128): (2, 64, 64),
(4, 256): (2, 64, 64),
(4, 512): (2, 64, 64),
(4, 768): (2, 64, 64),
(4, 1024): (2, 64, 64),
(4, 1536): (2, 64, 64),
(4, 2048): (2, 64, 64),
(4, 3072): (2, 64, 64),
(4, 4096): (2, 64, 64),
(4, 6144): (4, 128, 64),
(4, 8192): (4, 128, 64),
(4, 12288): (2, 256, 64),
(4, 16384): (4, 128, 64),
(4, 24576): (2, 256, 64),
(4, 32768): (2, 256, 64),
(4, 40960): (4, 256, 64),
(4, 49152): (4, 256, 64),
(4, 57344): (4, 256, 64),
(4, 65536): (4, 256, 64),
(4, 73728): (4, 256, 64),
(4, 81920): (4, 512, 128),
(4, 90112): (4, 512, 128),
(4, 98304): (4, 512, 128),
(4, 106496): (4, 512, 128),
(4, 114688): (4, 512, 128),
(4, 122880): (4, 512, 128),
(4, 131072): (4, 512, 128),
(8, 64): (4, 64, 64),
(8, 128): (2, 64, 64),
(8, 256): (2, 64, 64),
(8, 512): (4, 64, 64),
(8, 768): (2, 64, 64),
(8, 1024): (2, 64, 64),
(8, 1536): (2, 64, 64),
(8, 2048): (2, 64, 64),
(8, 3072): (4, 128, 64),
(8, 4096): (4, 128, 64),
(8, 6144): (4, 128, 64),
(8, 8192): (4, 128, 64),
(8, 12288): (2, 256, 64),
(8, 16384): (2, 256, 64),
(8, 24576): (4, 128, 64),
(8, 32768): (4, 256, 64),
(8, 40960): (4, 256, 64),
(8, 49152): (4, 512, 128),
(8, 57344): (4, 256, 64),
(8, 65536): (4, 512, 128),
(8, 73728): (4, 512, 128),
(8, 81920): (4, 512, 128),
(8, 90112): (4, 512, 128),
(8, 98304): (8, 512, 64),
(8, 106496): (8, 512, 64),
(8, 114688): (4, 1024, 128),
(8, 122880): (4, 512, 128),
(8, 131072): (8, 512, 64),
(16, 64): (2, 64, 64),
(16, 128): (2, 64, 64),
(16, 256): (2, 64, 64),
(16, 512): (2, 64, 64),
(16, 768): (2, 64, 64),
(16, 1024): (2, 64, 64),
(16, 1536): (4, 128, 64),
(16, 2048): (4, 128, 64),
(16, 3072): (2, 256, 64),
(16, 4096): (4, 128, 64),
(16, 6144): (2, 256, 64),
(16, 8192): (2, 256, 64),
(16, 12288): (4, 128, 64),
(16, 16384): (4, 256, 64),
(16, 24576): (4, 128, 64),
(16, 32768): (4, 512, 128),
(16, 40960): (4, 256, 64),
(16, 49152): (4, 256, 64),
(16, 57344): (4, 512, 128),
(16, 65536): (4, 1024, 128),
(16, 73728): (4, 1024, 128),
(16, 81920): (4, 512, 128),
(16, 90112): (4, 512, 128),
(16, 98304): (4, 512, 128),
(16, 106496): (4, 512, 128),
(16, 114688): (8, 512, 64),
(16, 122880): (8, 512, 64),
(16, 131072): (4, 2048, 128),
(32, 64): (2, 64, 64),
(32, 128): (2, 64, 64),
(32, 256): (2, 64, 64),
(32, 512): (2, 64, 64),
(32, 768): (4, 128, 64),
(32, 1024): (2, 128, 64),
(32, 1536): (4, 128, 64),
(32, 2048): (2, 256, 64),
(32, 3072): (2, 256, 64),
(32, 4096): (2, 256, 64),
(32, 6144): (4, 256, 64),
(32, 8192): (4, 128, 64),
(32, 12288): (4, 512, 128),
(32, 16384): (4, 512, 128),
(32, 24576): (4, 256, 64),
(32, 32768): (4, 512, 128),
(32, 40960): (4, 512, 128),
(32, 49152): (4, 512, 128),
(32, 57344): (4, 1024, 128),
(32, 65536): (4, 512, 128),
(32, 73728): (8, 512, 64),
(32, 81920): (8, 512, 64),
(32, 90112): (8, 512, 64),
(32, 98304): (8, 512, 64),
(32, 106496): (8, 512, 64),
(32, 114688): (8, 512, 64),
(32, 122880): (8, 1024, 64),
(32, 131072): (8, 512, 64),
(64, 64): (2, 64, 64),
(64, 128): (2, 64, 64),
(64, 256): (2, 64, 64),
(64, 512): (4, 128, 64),
(64, 768): (4, 128, 64),
(64, 1024): (2, 256, 64),
(64, 1536): (2, 256, 64),
(64, 2048): (4, 128, 64),
(64, 3072): (4, 256, 64),
(64, 4096): (4, 256, 64),
(64, 6144): (4, 128, 64),
(64, 8192): (4, 512, 128),
(64, 12288): (4, 256, 64),
(64, 16384): (4, 1024, 128),
(64, 24576): (4, 512, 128),
(64, 32768): (4, 512, 128),
(64, 40960): (4, 512, 128),
(64, 49152): (4, 1024, 128),
(64, 57344): (4, 512, 128),
(64, 65536): (4, 1024, 128),
(64, 73728): (8, 1024, 64),
(64, 81920): (8, 512, 64),
(64, 90112): (8, 1024, 64),
(64, 98304): (8, 1024, 64),
(64, 106496): (8, 1024, 64),
(64, 114688): (8, 512, 64),
(64, 122880): (8, 512, 64),
(64, 131072): (8, 512, 64),
(128, 64): (2, 64, 64),
(128, 128): (2, 64, 64),
(128, 256): (2, 128, 64),
(128, 512): (4, 128, 64),
(128, 768): (2, 256, 64),
(128, 1024): (2, 256, 64),
(128, 1536): (4, 256, 64),
(128, 2048): (4, 256, 64),
(128, 3072): (4, 128, 64),
(128, 4096): (4, 512, 128),
(128, 6144): (4, 256, 64),
(128, 8192): (4, 1024, 128),
(128, 12288): (4, 512, 128),
(128, 16384): (4, 512, 128),
(128, 24576): (4, 1024, 128),
(128, 32768): (4, 512, 128),
(128, 40960): (4, 1024, 128),
(128, 49152): (4, 2048, 128),
(128, 57344): (8, 512, 64),
(128, 65536): (8, 512, 64),
(128, 73728): (8, 1024, 64),
(128, 81920): (8, 1024, 64),
(128, 90112): (4, 4096, 128),
(128, 98304): (8, 1024, 64),
(128, 106496): (8, 1024, 64),
(128, 114688): (8, 1024, 64),
(128, 122880): (8, 1024, 64),
(128, 131072): (8, 1024, 64),
(256, 64): (2, 64, 64),
(256, 128): (4, 128, 64),
(256, 256): (4, 128, 64),
(256, 512): (2, 256, 64),
(256, 768): (4, 256, 64),
(256, 1024): (4, 256, 64),
(256, 1536): (4, 128, 64),
(256, 2048): (4, 512, 128),
(256, 3072): (4, 256, 64),
(256, 4096): (4, 128, 64),
(256, 6144): (4, 512, 128),
(256, 8192): (4, 512, 128),
(256, 12288): (4, 1024, 128),
(256, 16384): (4, 1024, 128),
(256, 24576): (4, 1024, 128),
(256, 32768): (4, 1024, 128),
(256, 40960): (4, 2048, 128),
(256, 49152): (4, 2048, 128),
(256, 57344): (8, 1024, 64),
(256, 65536): (8, 1024, 64),
(256, 73728): (8, 2048, 64),
(256, 81920): (8, 2048, 64),
(256, 90112): (8, 2048, 64),
(256, 98304): (8, 1024, 64),
(256, 106496): (8, 1024, 64),
(256, 114688): (8, 2048, 64),
(256, 122880): (8, 2048, 64),
(256, 131072): (8, 1024, 64),
(384, 64): (2, 64, 64),
(384, 128): (2, 128, 64),
(384, 256): (4, 128, 64),
(384, 512): (4, 256, 64),
(384, 768): (4, 256, 64),
(384, 1024): (4, 128, 64),
(384, 1536): (4, 512, 128),
(384, 2048): (4, 256, 64),
(384, 3072): (4, 1024, 128),
(384, 4096): (4, 512, 128),
(384, 6144): (4, 1024, 128),
(384, 8192): (4, 1024, 128),
(384, 12288): (4, 1024, 128),
(384, 16384): (4, 1024, 128),
(384, 24576): (4, 1024, 128),
(384, 32768): (4, 1024, 128),
(384, 40960): (4, 1024, 128),
(384, 49152): (8, 1024, 64),
(384, 57344): (8, 2048, 64),
(384, 65536): (8, 1024, 64),
(384, 73728): (8, 2048, 64),
(384, 81920): (8, 2048, 64),
(384, 90112): (8, 1024, 64),
(384, 98304): (8, 2048, 64),
(384, 106496): (8, 2048, 64),
(384, 114688): (8, 2048, 64),
(384, 122880): (8, 4096, 64),
(384, 131072): (8, 2048, 64),
(512, 64): (2, 64, 64),
(512, 128): (4, 128, 64),
(512, 256): (2, 256, 64),
(512, 512): (4, 256, 64),
(512, 768): (4, 128, 64),
(512, 1024): (4, 512, 128),
(512, 1536): (4, 256, 64),
(512, 2048): (4, 1024, 128),
(512, 3072): (4, 512, 128),
(512, 4096): (4, 512, 128),
(512, 6144): (4, 512, 128),
(512, 8192): (4, 1024, 128),
(512, 12288): (4, 1024, 128),
(512, 16384): (4, 1024, 128),
(512, 24576): (4, 1024, 128),
(512, 32768): (4, 1024, 128),
(512, 40960): (4, 2048, 128),
(512, 49152): (8, 1024, 64),
(512, 57344): (8, 2048, 64),
(512, 65536): (8, 1024, 64),
(512, 73728): (8, 2048, 64),
(512, 81920): (8, 2048, 64),
(512, 90112): (8, 2048, 64),
(512, 98304): (8, 2048, 64),
(512, 106496): (8, 2048, 64),
(512, 114688): (8, 4096, 64),
(512, 122880): (8, 4096, 64),
(512, 131072): (8, 2048, 64),
(640, 64): (2, 64, 64),
(640, 128): (4, 128, 64),
(640, 256): (4, 128, 64),
(640, 512): (4, 128, 64),
(640, 768): (4, 256, 64),
(640, 1024): (4, 256, 64),
(640, 1536): (4, 512, 128),
(640, 2048): (4, 512, 128),
(640, 3072): (4, 1024, 128),
(640, 4096): (4, 1024, 128),
(640, 6144): (4, 1024, 128),
(640, 8192): (4, 1024, 128),
(640, 12288): (4, 1024, 128),
(640, 16384): (4, 1024, 128),
(640, 24576): (4, 1024, 128),
(640, 32768): (4, 1024, 128),
(640, 40960): (4, 2048, 128),
(640, 49152): (8, 2048, 64),
(640, 57344): (8, 1024, 64),
(640, 65536): (8, 2048, 64),
(640, 73728): (8, 2048, 64),
(640, 81920): (8, 2048, 64),
(640, 90112): (8, 2048, 64),
(640, 98304): (8, 4096, 64),
(640, 106496): (8, 4096, 64),
(640, 114688): (8, 2048, 64),
(640, 122880): (8, 2048, 64),
(640, 131072): (8, 4096, 64),
(768, 64): (2, 64, 64),
(768, 128): (4, 128, 64),
(768, 256): (4, 256, 64),
(768, 512): (4, 128, 64),
(768, 768): (4, 256, 64),
(768, 1024): (4, 256, 64),
(768, 1536): (4, 512, 128),
(768, 2048): (4, 512, 128),
(768, 3072): (4, 1024, 128),
(768, 4096): (4, 1024, 128),
(768, 6144): (4, 1024, 128),
(768, 8192): (4, 1024, 128),
(768, 12288): (4, 1024, 128),
(768, 16384): (4, 1024, 128),
(768, 24576): (4, 1024, 128),
(768, 32768): (4, 1024, 128),
(768, 40960): (4, 2048, 128),
(768, 49152): (8, 2048, 64),
(768, 57344): (8, 2048, 64),
(768, 65536): (8, 2048, 64),
(768, 73728): (8, 2048, 64),
(768, 81920): (8, 4096, 64),
(768, 90112): (8, 2048, 64),
(768, 98304): (8, 4096, 64),
(768, 106496): (8, 4096, 64),
(768, 114688): (8, 2048, 64),
(768, 122880): (8, 4096, 64),
(768, 131072): (8, 4096, 64),
(896, 64): (8, 64, 64),
(896, 128): (4, 128, 64),
(896, 256): (4, 128, 64),
(896, 512): (4, 256, 64),
(896, 768): (4, 128, 64),
(896, 1024): (4, 256, 64),
(896, 1536): (4, 512, 128),
(896, 2048): (4, 512, 128),
(896, 3072): (4, 512, 128),
(896, 4096): (4, 1024, 128),
(896, 6144): (4, 1024, 128),
(896, 8192): (4, 1024, 128),
(896, 12288): (4, 1024, 128),
(896, 16384): (4, 2048, 128),
(896, 24576): (4, 1024, 128),
(896, 32768): (4, 1024, 128),
(896, 40960): (4, 2048, 128),
(896, 49152): (8, 1024, 64),
(896, 57344): (8, 2048, 64),
(896, 65536): (8, 2048, 64),
(896, 73728): (8, 4096, 64),
(896, 81920): (8, 2048, 64),
(896, 90112): (8, 2048, 64),
(896, 98304): (8, 2048, 64),
(896, 106496): (8, 2048, 64),
(896, 114688): (8, 4096, 64),
(896, 122880): (8, 4096, 64),
(896, 131072): (8, 4096, 64),
(1024, 64): (2, 64, 64),
(1024, 128): (4, 128, 64),
(1024, 256): (4, 256, 64),
(1024, 512): (4, 512, 128),
(1024, 768): (4, 256, 64),
(1024, 1024): (4, 1024, 128),
(1024, 1536): (4, 512, 128),
(1024, 2048): (4, 512, 128),
(1024, 3072): (4, 1024, 128),
(1024, 4096): (4, 1024, 128),
(1024, 6144): (4, 1024, 128),
(1024, 8192): (4, 1024, 128),
(1024, 12288): (4, 1024, 128),
(1024, 16384): (4, 1024, 128),
(1024, 24576): (4, 2048, 128),
(1024, 32768): (4, 1024, 128),
(1024, 40960): (4, 2048, 128),
(1024, 49152): (8, 2048, 64),
(1024, 57344): (8, 2048, 64),
(1024, 65536): (8, 2048, 64),
(1024, 73728): (8, 4096, 64),
(1024, 81920): (8, 4096, 64),
(1024, 90112): (8, 2048, 64),
(1024, 98304): (8, 4096, 64),
(1024, 106496): (8, 2048, 64),
(1024, 114688): (8, 4096, 64),
(1024, 122880): (8, 4096, 64),
(1024, 131072): (8, 4096, 64),
(1280, 64): (2, 64, 64),
(1280, 128): (4, 128, 64),
(1280, 256): (4, 128, 64),
(1280, 512): (4, 256, 64),
(1280, 768): (4, 128, 64),
(1280, 1024): (4, 512, 128),
(1280, 1536): (4, 512, 128),
(1280, 2048): (4, 512, 128),
(1280, 3072): (4, 1024, 128),
(1280, 4096): (4, 1024, 128),
(1280, 6144): (4, 1024, 128),
(1280, 8192): (4, 1024, 128),
(1280, 12288): (4, 2048, 128),
(1280, 16384): (4, 2048, 128),
(1280, 24576): (4, 1024, 128),
(1280, 32768): (4, 1024, 128),
(1280, 40960): (4, 2048, 128),
(1280, 49152): (8, 4096, 64),
(1280, 57344): (8, 2048, 64),
(1280, 65536): (8, 4096, 64),
(1280, 73728): (8, 4096, 64),
(1280, 81920): (8, 2048, 64),
(1280, 90112): (8, 4096, 64),
(1280, 98304): (8, 4096, 64),
(1280, 106496): (8, 4096, 64),
(1280, 114688): (8, 4096, 64),
(1280, 122880): (8, 4096, 64),
(1280, 131072): (8, 4096, 64),
(1536, 64): (2, 64, 64),
(1536, 128): (4, 128, 64),
(1536, 256): (4, 128, 64),
(1536, 512): (4, 256, 64),
(1536, 768): (4, 256, 64),
(1536, 1024): (4, 512, 128),
(1536, 1536): (4, 512, 128),
(1536, 2048): (4, 1024, 128),
(1536, 3072): (4, 1024, 128),
(1536, 4096): (4, 1024, 128),
(1536, 6144): (4, 1024, 128),
(1536, 8192): (4, 1024, 128),
(1536, 12288): (4, 2048, 128),
(1536, 16384): (4, 2048, 128),
(1536, 24576): (4, 2048, 128),
(1536, 32768): (4, 1024, 128),
(1536, 40960): (4, 2048, 128),
(1536, 49152): (8, 4096, 64),
(1536, 57344): (8, 2048, 64),
(1536, 65536): (8, 4096, 64),
(1536, 73728): (8, 4096, 64),
(1536, 81920): (8, 4096, 64),
(1536, 90112): (8, 4096, 64),
(1536, 98304): (8, 4096, 64),
(1536, 106496): (8, 4096, 64),
(1536, 114688): (8, 4096, 64),
(1536, 122880): (8, 8192, 64),
(1536, 131072): (8, 4096, 64),
(1792, 64): (2, 64, 64),
(1792, 128): (4, 128, 64),
(1792, 256): (4, 128, 64),
(1792, 512): (4, 256, 64),
(1792, 768): (4, 256, 64),
(1792, 1024): (4, 512, 128),
(1792, 1536): (4, 512, 128),
(1792, 2048): (4, 1024, 128),
(1792, 3072): (4, 1024, 128),
(1792, 4096): (4, 1024, 128),
(1792, 6144): (4, 1024, 128),
(1792, 8192): (4, 2048, 128),
(1792, 12288): (4, 2048, 128),
(1792, 16384): (4, 2048, 128),
(1792, 24576): (4, 2048, 128),
(1792, 32768): (4, 1024, 128),
(1792, 40960): (4, 2048, 128),
(1792, 49152): (8, 2048, 64),
(1792, 57344): (8, 4096, 64),
(1792, 65536): (8, 4096, 64),
(1792, 73728): (8, 4096, 64),
(1792, 81920): (8, 4096, 64),
(1792, 90112): (8, 4096, 64),
(1792, 98304): (8, 4096, 64),
(1792, 106496): (8, 4096, 64),
(1792, 114688): (8, 8192, 64),
(1792, 122880): (8, 8192, 64),
(1792, 131072): (8, 8192, 64),
(2048, 64): (2, 64, 64),
(2048, 128): (4, 128, 64),
(2048, 256): (4, 128, 64),
(2048, 512): (4, 512, 128),
(2048, 768): (4, 256, 64),
(2048, 1024): (4, 512, 128),
(2048, 1536): (4, 512, 128),
(2048, 2048): (4, 1024, 128),
(2048, 3072): (4, 1024, 128),
(2048, 4096): (4, 1024, 128),
(2048, 6144): (4, 1024, 128),
(2048, 8192): (4, 1024, 128),
(2048, 12288): (4, 2048, 128),
(2048, 16384): (4, 2048, 128),
(2048, 24576): (4, 2048, 128),
(2048, 32768): (4, 2048, 128),
(2048, 40960): (4, 2048, 128),
(2048, 49152): (8, 4096, 64),
(2048, 57344): (8, 4096, 64),
(2048, 65536): (8, 4096, 64),
(2048, 73728): (8, 4096, 64),
(2048, 81920): (8, 4096, 64),
(2048, 90112): (8, 4096, 64),
(2048, 98304): (8, 8192, 64),
(2048, 106496): (8, 4096, 64),
(2048, 114688): (8, 8192, 64),
(2048, 122880): (8, 4096, 64),
(2048, 131072): (8, 8192, 64),
(2560, 64): (4, 64, 64),
(2560, 128): (4, 128, 64),
(2560, 256): (4, 256, 64),
(2560, 512): (4, 512, 128),
(2560, 768): (4, 256, 64),
(2560, 1024): (4, 512, 128),
(2560, 1536): (4, 512, 128),
(2560, 2048): (4, 1024, 128),
(2560, 3072): (4, 1024, 128),
(2560, 4096): (4, 1024, 128),
(2560, 6144): (4, 2048, 128),
(2560, 8192): (4, 2048, 128),
(2560, 12288): (4, 2048, 128),
(2560, 16384): (4, 2048, 128),
(2560, 24576): (4, 2048, 128),
(2560, 32768): (4, 1024, 128),
(2560, 40960): (4, 2048, 128),
(2560, 49152): (8, 4096, 64),
(2560, 57344): (8, 4096, 64),
(2560, 65536): (8, 4096, 64),
(2560, 73728): (8, 4096, 64),
(2560, 81920): (8, 4096, 64),
(2560, 90112): (8, 8192, 64),
(2560, 98304): (8, 4096, 64),
(2560, 106496): (8, 8192, 64),
(2560, 114688): (8, 4096, 64),
(2560, 122880): (8, 8192, 64),
(2560, 131072): (8, 4096, 64),
(3072, 64): (4, 64, 64),
(3072, 128): (4, 128, 64),
(3072, 256): (4, 256, 64),
(3072, 512): (4, 512, 128),
(3072, 768): (8, 256, 64),
(3072, 1024): (4, 1024, 128),
(3072, 1536): (4, 512, 128),
(3072, 2048): (4, 1024, 128),
(3072, 3072): (4, 1024, 128),
(3072, 4096): (4, 1024, 128),
(3072, 6144): (4, 2048, 128),
(3072, 8192): (4, 2048, 128),
(3072, 12288): (4, 2048, 128),
(3072, 16384): (4, 2048, 128),
(3072, 24576): (4, 2048, 128),
(3072, 32768): (4, 2048, 128),
(3072, 40960): (8, 2048, 64),
(3072, 49152): (8, 4096, 64),
(3072, 57344): (8, 4096, 64),
(3072, 65536): (8, 4096, 64),
(3072, 73728): (8, 8192, 64),
(3072, 81920): (8, 4096, 64),
(3072, 90112): (8, 8192, 64),
(3072, 98304): (8, 8192, 64),
(3072, 106496): (8, 4096, 64),
(3072, 114688): (8, 8192, 64),
(3072, 122880): (8, 8192, 64),
(3072, 131072): (8, 8192, 64),
(3584, 64): (4, 64, 64),
(3584, 128): (4, 128, 64),
(3584, 256): (4, 128, 64),
(3584, 512): (4, 512, 128),
(3584, 768): (4, 256, 64),
(3584, 1024): (4, 512, 128),
(3584, 1536): (4, 512, 128),
(3584, 2048): (4, 1024, 128),
(3584, 3072): (4, 1024, 128),
(3584, 4096): (4, 2048, 128),
(3584, 6144): (4, 2048, 128),
(3584, 8192): (4, 2048, 128),
(3584, 12288): (4, 2048, 128),
(3584, 16384): (4, 2048, 128),
(3584, 24576): (4, 2048, 128),
(3584, 32768): (4, 2048, 128),
(3584, 40960): (8, 4096, 64),
(3584, 49152): (8, 4096, 64),
(3584, 57344): (8, 4096, 64),
(3584, 65536): (8, 4096, 64),
(3584, 73728): (8, 8192, 64),
(3584, 81920): (8, 4096, 64),
(3584, 90112): (8, 4096, 64),
(3584, 98304): (8, 8192, 64),
(3584, 106496): (8, 8192, 64),
(3584, 114688): (8, 8192, 64),
(3584, 122880): (8, 8192, 64),
(3584, 131072): (8, 8192, 64),
(4096, 64): (4, 64, 64),
(4096, 128): (4, 128, 64),
(4096, 256): (4, 128, 64),
(4096, 512): (4, 512, 128),
(4096, 768): (8, 256, 64),
(4096, 1024): (4, 512, 128),
(4096, 1536): (4, 512, 128),
(4096, 2048): (4, 1024, 128),
(4096, 3072): (4, 1024, 128),
(4096, 4096): (4, 1024, 128),
(4096, 6144): (4, 2048, 128),
(4096, 8192): (4, 2048, 128),
(4096, 12288): (4, 2048, 128),
(4096, 16384): (4, 2048, 128),
(4096, 24576): (4, 2048, 128),
(4096, 32768): (4, 2048, 128),
(4096, 40960): (8, 4096, 64),
(4096, 49152): (8, 4096, 64),
(4096, 57344): (8, 4096, 64),
(4096, 65536): (8, 4096, 64),
(4096, 73728): (8, 8192, 64),
(4096, 81920): (8, 4096, 64),
(4096, 90112): (8, 4096, 64),
(4096, 98304): (8, 8192, 64),
(4096, 106496): (8, 8192, 64),
(4096, 114688): (8, 8192, 64),
(4096, 122880): (8, 8192, 64),
(4096, 131072): (8, 8192, 64),
}
def get_tuned_config(m: int, n: int) -> Tuple[int, int, int]:
"""Lookup tuned config for (m, n) using floor. Returns (m_split, blk_n1, blk_n2)."""
if m <= M_REPR_TABLE[0]:
m_repr = M_REPR_TABLE[0]
elif m >= M_REPR_TABLE[-1]:
m_repr = M_REPR_TABLE[-1]
else:
idx = bisect.bisect_right(M_REPR_TABLE, m) - 1
m_repr = M_REPR_TABLE[idx]
if n <= N_REPR_TABLE[0]:
n_repr = N_REPR_TABLE[0]
elif n >= N_REPR_TABLE[-1]:
n_repr = N_REPR_TABLE[-1]
else:
idx = bisect.bisect_right(N_REPR_TABLE, n) - 1
n_repr = N_REPR_TABLE[idx]
return CONFIG_MAP[(m_repr, n_repr)]
from typing import Optional, Tuple
import functools
import tilelang
import tilelang.language as T
import torch
tilelang.set_log_level("WARNING")
cu_count = torch.cuda.get_device_properties("cuda").multi_processor_count
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
# tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
tilelang.PassConfigKey.TL_DISABLE_DATA_RACE_CHECK: True,
}
BF16 = "bfloat16"
FP8 = "float8_e4m3"
FP32 = "float32"
def fast_log2_ceil(x):
bits_x = T.reinterpret("uint32", x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
):
M = T.symbolic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale else 2
blk_m = 32
group_size = 128
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
s_local = T.alloc_fragment((blk_m,), scale_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = s_local[i]
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), "Input tensor must be contiguous"
assert (
x.size(-1) % block_size == 0
), f"Last dimension size must be divisible by block_size (block_size={block_size})"
N = x.size(-1)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
return y, s
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(
h: int, d: int, m_split: int, blk_n1: int, blk_n2: int, disable_buffer_ops: bool = False, threads: int = 256, clear_accum: bool = True
):
b, m, n = T.symbolic("b"), T.symbolic("m"), T.symbolic("n")
# if m_split * h > 128, use Square policy to avoid register spill
gemm_policy = T.GemmWarpPolicy.FullRow if m_split * h <= 128 else T.GemmWarpPolicy.Square
@T.prim_func
def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, T.ceildiv(n, blk_n1), m, threads=threads) as (
i_b, i1_n, i_m_block
):
if disable_buffer_ops:
T.disable_buffer_ops(o)
m_start = i_m_block
q_smem = T.alloc_shared((h, d), FP8)
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.annotate_layout({
q_smem: tilelang.layout.make_hcu_swizzled_layout(q_smem, major_pack=1),
k_smem: tilelang.layout.make_hcu_swizzled_layout(k_smem, major_pack=1),
})
q_frag = T.alloc_fragment((h, d), FP8)
q_s_frag = T.alloc_fragment(h, FP32)
k_frag = T.alloc_fragment((blk_n2, d), FP8)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
logits = T.alloc_fragment((blk_n2, h), FP32)
logits_sum = T.alloc_fragment(blk_n2, FP32)
T.copy(q[i_b, m_start, 0, 0], q_smem)
T.copy(q_smem, q_frag)
T.copy(q_s[i_b, m_start, 0], q_s_frag)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=0):
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
T.clear(logits)
T.copy(k_smem, k_frag)
T.gemm(k_frag, q_frag, logits, transpose_A=False, transpose_B=True, policy=gemm_policy)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
T.reduce_sum(logits, logits_sum, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum, o[i_b, m_start, i1_n * blk_n1 + i2_n * blk_n2])
@T.prim_func
def fp8_index_kernel_1(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, T.ceildiv(n, blk_n1), T.ceildiv(m, 2), threads=threads) as (
i_b, i1_n, i_m_block
):
if disable_buffer_ops:
T.disable_buffer_ops(o)
m_start = i_m_block * 2
q_smem0 = T.alloc_shared((h, d), FP8)
q_smem1 = T.alloc_shared((h, d), FP8)
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.annotate_layout({
q_smem0: tilelang.layout.make_hcu_swizzled_layout(q_smem0, major_pack=1),
q_smem1: tilelang.layout.make_hcu_swizzled_layout(q_smem1, major_pack=1),
k_smem: tilelang.layout.make_hcu_swizzled_layout(k_smem, major_pack=1),
})
q_frag0 = T.alloc_fragment((h, d), FP8)
q_frag1 = T.alloc_fragment((h, d), FP8)
q_s_frag0 = T.alloc_fragment(h, FP32)
q_s_frag1 = T.alloc_fragment(h, FP32)
k_frag = T.alloc_fragment((blk_n2, d), FP8)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
logits0 = T.alloc_fragment((blk_n2, h), FP32)
logits1 = T.alloc_fragment((blk_n2, h), FP32)
logits_sum0 = T.alloc_fragment(blk_n2, FP32)
logits_sum1 = T.alloc_fragment(blk_n2, FP32)
T.copy(q[i_b, m_start, 0, 0], q_smem0)
T.copy(q[i_b, m_start + 1, 0, 0], q_smem1)
T.copy(q_smem0, q_frag0)
T.copy(q_smem1, q_frag1)
T.copy(q_s[i_b, m_start, 0], q_s_frag0)
T.copy(q_s[i_b, m_start + 1, 0], q_s_frag1)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=0):
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
T.clear(logits0)
T.clear(logits1)
T.copy(k_smem, k_frag)
T.gemm(k_frag, q_frag0, logits0, transpose_A=False, transpose_B=True, policy=gemm_policy)
T.gemm(k_frag, q_frag1, logits1, transpose_A=False, transpose_B=True, policy=gemm_policy)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits0[i3_n, i_h] = T.max(logits0[i3_n, i_h], 0) * q_s_frag0[i_h]
logits1[i3_n, i_h] = T.max(logits1[i3_n, i_h], 0) * q_s_frag1[i_h]
T.reduce_sum(logits0, logits_sum0, dim=1)
T.reduce_sum(logits1, logits_sum1, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum0[i3_n] *= k_s_frag[i3_n]
logits_sum1[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum0, o[i_b, m_start, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum1, o[i_b, m_start + 1, i1_n * blk_n1 + i2_n * blk_n2])
@T.prim_func
def fp8_index_kernel_2(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, T.ceildiv(n, blk_n1), T.ceildiv(m, 4), threads=threads) as (
i_b, i1_n, i_m_block
):
if disable_buffer_ops:
T.disable_buffer_ops(o)
m_start = i_m_block * 4
q_smem0 = T.alloc_shared((h, d), FP8)
q_smem1 = T.alloc_shared((h, d), FP8)
q_smem2 = T.alloc_shared((h, d), FP8)
q_smem3 = T.alloc_shared((h, d), FP8)
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.annotate_layout({
q_smem0: tilelang.layout.make_hcu_swizzled_layout(q_smem0, major_pack=1),
q_smem1: tilelang.layout.make_hcu_swizzled_layout(q_smem1, major_pack=1),
q_smem2: tilelang.layout.make_hcu_swizzled_layout(q_smem2, major_pack=1),
q_smem3: tilelang.layout.make_hcu_swizzled_layout(q_smem3, major_pack=1),
k_smem: tilelang.layout.make_hcu_swizzled_layout(k_smem, major_pack=1),
})
q_frag0 = T.alloc_fragment((h, d), FP8)
q_frag1 = T.alloc_fragment((h, d), FP8)
q_frag2 = T.alloc_fragment((h, d), FP8)
q_frag3 = T.alloc_fragment((h, d), FP8)
q_s_frag0 = T.alloc_fragment(h, FP32)
q_s_frag1 = T.alloc_fragment(h, FP32)
q_s_frag2 = T.alloc_fragment(h, FP32)
q_s_frag3 = T.alloc_fragment(h, FP32)
T.copy(q[i_b, m_start, 0, 0], q_smem0)
T.copy(q[i_b, m_start + 1, 0, 0], q_smem1)
T.copy(q[i_b, m_start + 2, 0, 0], q_smem2)
T.copy(q[i_b, m_start + 3, 0, 0], q_smem3)
T.copy(q_smem0, q_frag0)
T.copy(q_smem1, q_frag1)
T.copy(q_smem2, q_frag2)
T.copy(q_smem3, q_frag3)
T.copy(q_s[i_b, m_start, 0], q_s_frag0)
T.copy(q_s[i_b, m_start + 1, 0], q_s_frag1)
T.copy(q_s[i_b, m_start + 2, 0], q_s_frag2)
T.copy(q_s[i_b, m_start + 3, 0], q_s_frag3)
k_frag = T.alloc_fragment((blk_n2, d), FP8)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
logits0 = T.alloc_fragment((blk_n2, h), FP32)
logits1 = T.alloc_fragment((blk_n2, h), FP32)
logits2 = T.alloc_fragment((blk_n2, h), FP32)
logits3 = T.alloc_fragment((blk_n2, h), FP32)
logits_sum0 = T.alloc_fragment(blk_n2, FP32)
logits_sum1 = T.alloc_fragment(blk_n2, FP32)
logits_sum2 = T.alloc_fragment(blk_n2, FP32)
logits_sum3 = T.alloc_fragment(blk_n2, FP32)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=0):
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
T.clear(logits0)
T.clear(logits1)
T.clear(logits2)
T.clear(logits3)
T.copy(k_smem, k_frag)
T.gemm(k_frag, q_frag0, logits0, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag1, logits1, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag2, logits2, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag3, logits3, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits0[i3_n, i_h] = T.max(logits0[i3_n, i_h], 0) * q_s_frag0[i_h]
logits1[i3_n, i_h] = T.max(logits1[i3_n, i_h], 0) * q_s_frag1[i_h]
logits2[i3_n, i_h] = T.max(logits2[i3_n, i_h], 0) * q_s_frag2[i_h]
logits3[i3_n, i_h] = T.max(logits3[i3_n, i_h], 0) * q_s_frag3[i_h]
T.reduce_sum(logits0, logits_sum0, dim=1)
T.reduce_sum(logits1, logits_sum1, dim=1)
T.reduce_sum(logits2, logits_sum2, dim=1)
T.reduce_sum(logits3, logits_sum3, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum0[i3_n] *= k_s_frag[i3_n]
logits_sum1[i3_n] *= k_s_frag[i3_n]
logits_sum2[i3_n] *= k_s_frag[i3_n]
logits_sum3[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum0, o[i_b, m_start, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum1, o[i_b, m_start + 1, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum2, o[i_b, m_start + 2, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum3, o[i_b, m_start + 3, i1_n * blk_n1 + i2_n * blk_n2])
@T.prim_func
def fp8_index_kernel_3(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, T.ceildiv(n, blk_n1), T.ceildiv(m, 8), threads=threads) as (
i_b, i1_n, i_m_block
):
if disable_buffer_ops:
T.disable_buffer_ops(o)
m_start = i_m_block * 8
q_smem0 = T.alloc_shared((h, d), FP8)
q_smem1 = T.alloc_shared((h, d), FP8)
q_smem2 = T.alloc_shared((h, d), FP8)
q_smem3 = T.alloc_shared((h, d), FP8)
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.annotate_layout({
q_smem0: tilelang.layout.make_hcu_swizzled_layout(q_smem0, major_pack=1),
q_smem1: tilelang.layout.make_hcu_swizzled_layout(q_smem1, major_pack=1),
q_smem2: tilelang.layout.make_hcu_swizzled_layout(q_smem2, major_pack=1),
q_smem3: tilelang.layout.make_hcu_swizzled_layout(q_smem3, major_pack=1),
k_smem: tilelang.layout.make_hcu_swizzled_layout(k_smem, major_pack=1),
})
q_pre_frag0 = T.alloc_fragment((h, d), FP8)
q_pre_frag1 = T.alloc_fragment((h, d), FP8)
q_pre_frag2 = T.alloc_fragment((h, d), FP8)
q_pre_frag3 = T.alloc_fragment((h, d), FP8)
q_pre_frag4 = T.alloc_fragment((h, d), FP8)
q_pre_frag5 = T.alloc_fragment((h, d), FP8)
q_pre_frag6 = T.alloc_fragment((h, d), FP8)
q_pre_frag7 = T.alloc_fragment((h, d), FP8)
q_frag0 = T.alloc_fragment((h, d), FP8)
q_frag1 = T.alloc_fragment((h, d), FP8)
q_frag2 = T.alloc_fragment((h, d), FP8)
q_frag3 = T.alloc_fragment((h, d), FP8)
q_frag4 = T.alloc_fragment((h, d), FP8)
q_frag5 = T.alloc_fragment((h, d), FP8)
q_frag6 = T.alloc_fragment((h, d), FP8)
q_frag7 = T.alloc_fragment((h, d), FP8)
q_s_frag0 = T.alloc_fragment(h, FP32)
q_s_frag1 = T.alloc_fragment(h, FP32)
q_s_frag2 = T.alloc_fragment(h, FP32)
q_s_frag3 = T.alloc_fragment(h, FP32)
q_s_frag4 = T.alloc_fragment(h, FP32)
q_s_frag5 = T.alloc_fragment(h, FP32)
q_s_frag6 = T.alloc_fragment(h, FP32)
q_s_frag7 = T.alloc_fragment(h, FP32)
k_frag = T.alloc_fragment((blk_n2, d), FP8)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
logits0 = T.alloc_fragment((blk_n2, h), FP32)
logits1 = T.alloc_fragment((blk_n2, h), FP32)
logits2 = T.alloc_fragment((blk_n2, h), FP32)
logits3 = T.alloc_fragment((blk_n2, h), FP32)
logits4 = T.alloc_fragment((blk_n2, h), FP32)
logits5 = T.alloc_fragment((blk_n2, h), FP32)
logits6 = T.alloc_fragment((blk_n2, h), FP32)
logits7 = T.alloc_fragment((blk_n2, h), FP32)
logits_sum0 = T.alloc_fragment(blk_n2, FP32)
logits_sum1 = T.alloc_fragment(blk_n2, FP32)
logits_sum2 = T.alloc_fragment(blk_n2, FP32)
logits_sum3 = T.alloc_fragment(blk_n2, FP32)
logits_sum4 = T.alloc_fragment(blk_n2, FP32)
logits_sum5 = T.alloc_fragment(blk_n2, FP32)
logits_sum6 = T.alloc_fragment(blk_n2, FP32)
logits_sum7 = T.alloc_fragment(blk_n2, FP32)
T.copy(q[i_b, m_start, 0, 0], q_pre_frag0)
T.copy(q[i_b, m_start + 1, 0, 0], q_pre_frag1)
T.copy(q[i_b, m_start + 2, 0, 0], q_pre_frag2)
T.copy(q[i_b, m_start + 3, 0, 0], q_pre_frag3)
T.copy(q[i_b, m_start + 4, 0, 0], q_pre_frag4)
T.copy(q[i_b, m_start + 5, 0, 0], q_pre_frag5)
T.copy(q[i_b, m_start + 6, 0, 0], q_pre_frag6)
T.copy(q[i_b, m_start + 7, 0, 0], q_pre_frag7)
T.copy(q_s[i_b, m_start, 0], q_s_frag0)
T.copy(q_s[i_b, m_start + 1, 0], q_s_frag1)
T.copy(q_s[i_b, m_start + 2, 0], q_s_frag2)
T.copy(q_s[i_b, m_start + 3, 0], q_s_frag3)
T.copy(q_s[i_b, m_start + 4, 0], q_s_frag4)
T.copy(q_s[i_b, m_start + 5, 0], q_s_frag5)
T.copy(q_s[i_b, m_start + 6, 0], q_s_frag6)
T.copy(q_s[i_b, m_start + 7, 0], q_s_frag7)
T.copy(q_pre_frag0, q_smem0)
T.copy(q_pre_frag1, q_smem1)
T.copy(q_pre_frag2, q_smem2)
T.copy(q_pre_frag3, q_smem3)
T.copy(q_smem0, q_frag0)
T.copy(q_smem1, q_frag1)
T.copy(q_smem2, q_frag2)
T.copy(q_smem3, q_frag3)
T.copy(q_pre_frag4, q_smem0)
T.copy(q_pre_frag5, q_smem1)
T.copy(q_pre_frag6, q_smem2)
T.copy(q_pre_frag7, q_smem3)
T.copy(q_smem0, q_frag4)
T.copy(q_smem1, q_frag5)
T.copy(q_smem2, q_frag6)
T.copy(q_smem3, q_frag7)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=0):
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
T.clear(logits0)
T.clear(logits1)
T.clear(logits2)
T.clear(logits3)
T.clear(logits4)
T.clear(logits5)
T.clear(logits6)
T.clear(logits7)
T.copy(k_smem, k_frag)
T.gemm(k_frag, q_frag0, logits0, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag1, logits1, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag2, logits2, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag3, logits3, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag4, logits4, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag5, logits5, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag6, logits6, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
T.gemm(k_frag, q_frag7, logits7, transpose_A=False, transpose_B=True, k_pack=1, policy=gemm_policy)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits0[i3_n, i_h] = T.max(logits0[i3_n, i_h], 0) * q_s_frag0[i_h]
logits1[i3_n, i_h] = T.max(logits1[i3_n, i_h], 0) * q_s_frag1[i_h]
logits2[i3_n, i_h] = T.max(logits2[i3_n, i_h], 0) * q_s_frag2[i_h]
logits3[i3_n, i_h] = T.max(logits3[i3_n, i_h], 0) * q_s_frag3[i_h]
logits4[i3_n, i_h] = T.max(logits4[i3_n, i_h], 0) * q_s_frag4[i_h]
logits5[i3_n, i_h] = T.max(logits5[i3_n, i_h], 0) * q_s_frag5[i_h]
logits6[i3_n, i_h] = T.max(logits6[i3_n, i_h], 0) * q_s_frag6[i_h]
logits7[i3_n, i_h] = T.max(logits7[i3_n, i_h], 0) * q_s_frag7[i_h]
T.reduce_sum(logits0, logits_sum0, dim=1)
T.reduce_sum(logits1, logits_sum1, dim=1)
T.reduce_sum(logits2, logits_sum2, dim=1)
T.reduce_sum(logits3, logits_sum3, dim=1)
T.reduce_sum(logits4, logits_sum4, dim=1)
T.reduce_sum(logits5, logits_sum5, dim=1)
T.reduce_sum(logits6, logits_sum6, dim=1)
T.reduce_sum(logits7, logits_sum7, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum0[i3_n] *= k_s_frag[i3_n]
logits_sum1[i3_n] *= k_s_frag[i3_n]
logits_sum2[i3_n] *= k_s_frag[i3_n]
logits_sum3[i3_n] *= k_s_frag[i3_n]
logits_sum4[i3_n] *= k_s_frag[i3_n]
logits_sum5[i3_n] *= k_s_frag[i3_n]
logits_sum6[i3_n] *= k_s_frag[i3_n]
logits_sum7[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum0, o[i_b, m_start, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum1, o[i_b, m_start + 1, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum2, o[i_b, m_start + 2, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum3, o[i_b, m_start + 3, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum4, o[i_b, m_start + 4, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum5, o[i_b, m_start + 5, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum6, o[i_b, m_start + 6, i1_n * blk_n1 + i2_n * blk_n2])
T.copy(logits_sum7, o[i_b, m_start + 7, i1_n * blk_n1 + i2_n * blk_n2])
if m_split == 1:
return fp8_index_kernel_
elif m_split == 2:
return fp8_index_kernel_1
elif m_split == 4:
return fp8_index_kernel_2
else:
return fp8_index_kernel_3
@functools.lru_cache(maxsize=64)
def _get_config_module(h: int, d: int, cu_count: int):
"""Get config module for (h, d, cu_count). Returns None if not found."""
config_module_name = f"fp8_index_tuned_config_h{h}_d{d}_cu{cu_count}"
try:
return __import__(f"aiter.ops.tilelang.configs.fp8_index.{config_module_name}", fromlist=["get_tuned_config"])
except (ImportError, AttributeError):
return None
@functools.lru_cache(maxsize=128)
def _get_fp8_index_kernel(
h: int,
d: int,
m_split: int,
blk_n1: int,
blk_n2: int,
disable_buffer_ops: bool = False,
threads: int = 256,
clear_accum: bool = True,
):
"""Cached kernel creation. Dispatches to fp8_index_kernel_ / _1 / _2 based on m_split."""
assert m_split in (1, 2, 4, 8), "m_split must be 1, 2, 4, or 8"
print(
f"[fp8_index] kernel config: h={h} d={d} m_split={m_split} blk_n1={blk_n1} blk_n2={blk_n2} "
f"threads={threads} clear_accum={clear_accum} disable_buffer_ops={disable_buffer_ops}"
)
return fp8_index_kernel(
h, d, m_split, blk_n1, blk_n2, disable_buffer_ops, threads=threads, clear_accum=clear_accum
)
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor:
"""
Perform index score using FP8 precision.
Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
fp8 q @ fp8 k -> fp32 logits
relu(fp32 logits) * q_s (weights) -> fp32 logits
fp32 logits -> fp32 logits_sum
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
"""
b, m, h, d = q.shape
n = k.shape[1]
# Use tuned config; fallback to default if not found (run tune_fp8_index.py to generate)
mod = _get_config_module(h, d, cu_count)
if mod is not None:
m_split, blk_n1, blk_n2 = mod.get_tuned_config(m, n)
else:
m_split, blk_n1, blk_n2 = 1, 512, 128
disable_buffer_ops = False
if b * m * n * 4 >= 4294967296:
disable_buffer_ops = True
kernel = _get_fp8_index_kernel(
h, d, m_split, blk_n1, blk_n2, disable_buffer_ops, threads=256, clear_accum=False
)
return kernel(q, q_s, k, k_s)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment