Commit 85762c1a authored by Xiaowei.zhang's avatar Xiaowei.zhang
Browse files

Init the main branch for aiter

parent ae0b3521
Pipeline #3505 canceled with stages
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 98
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 183
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 146
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 160
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 42
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 42
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 46
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 46
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 86
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 98
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 183
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 146
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 160
}
}
\ No newline at end of file
# SPDX-License-Identifier: MIT
\ No newline at end of file
# SPDX-License-Identifier: MIT
from torch import Tensor
from ..jit.core import compile_ops
MD_NAME = "module_activation"
@compile_ops("module_activation")
def silu_and_mul(out: Tensor, input: Tensor) -> None: ...
@compile_ops("module_activation")
def scaled_silu_and_mul(out: Tensor, input: Tensor, scale: Tensor) -> None: ...
@compile_ops("module_activation")
def gelu_and_mul(out: Tensor, input: Tensor) -> None: ...
@compile_ops("module_activation")
def gelu_tanh_and_mul(out: Tensor, input: Tensor) -> None: ...
# SPDX-License-Identifier: MIT
from torch import Tensor
from ..jit.core import compile_ops
MD_NAME = "module_aiter_operator"
@compile_ops("module_aiter_operator")
def add(input: Tensor, other: Tensor) -> Tensor: ...
@compile_ops("module_aiter_operator")
def sub(input: Tensor, other: Tensor) -> Tensor: ...
@compile_ops("module_aiter_operator")
def mul(input: Tensor, other: Tensor) -> Tensor: ...
@compile_ops("module_aiter_operator")
def div(input: Tensor, other: Tensor) -> Tensor: ...
@compile_ops("module_aiter_operator")
def add_(input: Tensor, other: Tensor) -> Tensor: ...
@compile_ops("module_aiter_operator")
def sub_(input: Tensor, other: Tensor) -> Tensor: ...
@compile_ops("module_aiter_operator")
def mul_(input: Tensor, other: Tensor) -> Tensor: ...
@compile_ops("module_aiter_operator")
def div_(input: Tensor, other: Tensor) -> Tensor: ...
@compile_ops("module_aiter_unary")
def sigmoid(input: Tensor) -> Tensor: ...
@compile_ops("module_aiter_unary")
def tanh(input: Tensor) -> Tensor: ...
# SPDX-License-Identifier: MIT
import torch
from typing import Optional
from ..jit.core import (
compile_ops,
)
MD_NAME = "module_attention"
@compile_ops("module_attention")
def pa_fwd_naive(
# [num_seqs, num_heads, head_size]
query: torch.Tensor,
# [num_blocks, num_kv_heads, head_size/x, block_size, x]
key_cache: torch.Tensor,
# [num_blocks, num_kv_heads, head_size, block_size]
value_cache: torch.Tensor,
# [num_seqs, max_num_blocks_per_seq]
block_tables: torch.Tensor,
# [num_seqs]
context_lens: torch.Tensor,
k_dequant_scales: torch.Tensor,
v_dequant_scales: torch.Tensor,
max_seq_len: int,
num_kv_heads: int,
scale_s: float,
scale_k: float,
scale_v: float,
block_size: int,
quant_algo: int,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor: ...
@compile_ops("module_attention_asm")
def pa_fwd_asm(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_num_blocks: int,
K_QScale: Optional[torch.Tensor],
V_QScale: Optional[torch.Tensor],
out_: Optional[torch.Tensor] = None,
high_precision: Optional[
int
] = 1, # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
) -> torch.Tensor: ...
@compile_ops("module_pa")
def paged_attention_rocm(
out: torch.Tensor,
exp_sums: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
fp8_out_scale: Optional[torch.Tensor],
partition_size: int,
): ...
@compile_ops("module_pa_ragged")
def paged_attention_ragged(
out: torch.Tensor,
workspace_buffer: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_num_partitions: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_cache_layout: str,
logits_soft_cap: float,
k_scale: float,
v_scale: float,
fp8_out_scale: Optional[torch.Tensor],
partition_size: int,
): ...
MD_NAME = "module_mla_asm"
@compile_ops(MD_NAME)
def mla_decode_stage1_asm_fwd(
# [num_seqs, num_heads, head_size]
Q: torch.Tensor,
# [num_page, page_size, num_kv_heads, kv_lora_rank + qk_rope_head_dim]
KV: torch.Tensor,
# [batch_size+1]
qo_indptr: torch.Tensor,
# [batch_size+1]
kv_indptr: torch.Tensor,
# [num_page_used]
kv_page_indices: torch.Tensor,
# [batch_size]
kv_last_page_lens: torch.Tensor,
max_seqlen_q: int,
softmax_scale: float,
# [batch_size, num_kv_splits, num_heads, v_head_dim]
splitData: torch.Tensor,
# [batch_size, num_kv_splits, num_heads, 1]
splitLse: torch.Tensor,
): ...
@compile_ops(MD_NAME)
def mla_prefill_asm_fwd(
# [num_seqs, num_heads, head_size]
Q: torch.Tensor,
# [num_page, page_size, num_kv_heads, kv_lora_rank + qk_rope_head_dim]
KV: torch.Tensor,
# [batch_size+1]
qo_indptr: torch.Tensor,
# [batch_size+1]
kv_indptr: torch.Tensor,
# [num_page_used]
kv_page_indices: torch.Tensor,
# [batch_size]
kv_last_page_lens: torch.Tensor,
max_seqlen_q: int,
softmax_scale: float,
# [batch_size, num_kv_splits, num_heads, v_head_dim]
splitData: torch.Tensor,
# [batch_size, num_kv_splits, num_heads, 1]
splitLse: torch.Tensor,
): ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from typing import Optional
from ..jit.core import compile_ops
MD_NAME = "module_awq_dq_asm"
@compile_ops("module_awq_dq_asm")
def awq_dq_asm(
out: Tensor,
mat1: Tensor,
zero: Optional[Tensor] = None,
scalar: Optional[Tensor] = None,
)->None: ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from typing import Optional
from ..jit.core import compile_ops
MD_NAME = "module_awq_gemm_asm"
@compile_ops("module_awq_gemm_asm")
def awq_gemm_asm(
out: Tensor,
mat1: Tensor,
mat2: Tensor,
zero: Optional[Tensor] = None,
scalar: Optional[Tensor] = None,
)->None: ...
@compile_ops("module_awq_gemm_asm")
def awq_gemm_asm_tuning(
out: Tensor,
mat1: Tensor,
mat2: Tensor,
zero: Optional[Tensor] = None,
scalar: Optional[Tensor] = None,
solutionid: int = 0,
jsonfile: str = None,
)->None: ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from typing import Optional
import functools
import pandas as pd
from ..jit.core import (
compile_ops,
AITER_CORE_DIR,
)
from ..utility import dtypes
from ..jit.utils.chip_info import get_cu_num
@compile_ops("module_batched_gemm_a8w8", fc_name="batched_gemm_a8w8")
def batched_gemm_a8w8(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
out: Tensor,
bias: Optional[Tensor] = None,
splitK=0,
): ...
@functools.lru_cache(maxsize=1024)
def compute_batched_gemm_SplitK(
M: int, N: int, K: int, tile_m: int, tile_n: int, tile_k: int
):
cu_num = get_cu_num()
tile_num = ((M + tile_m - 1) // tile_m) * ((N + tile_n - 1) // tile_n)
cusPerTile = cu_num / tile_num
splitK = 0
while cusPerTile >= pow(2, splitK + 1) and (pow(2, splitK + 1) * tile_k) < 2 * K:
splitK += 1
return splitK
@functools.lru_cache(maxsize=1024)
def get_CKBatchedGEMM_config(
B: int,
M: int,
N: int,
K: int,
):
if not hasattr(get_CKBatchedGEMM_config, "ck_batched_gemm_dict"):
ck_batched_gemm_dict = pd.read_csv(
f"{AITER_CORE_DIR}/aiter/configs/a8w8_tuned_batched_gemm.csv"
).drop_duplicates()
get_CKBatchedGEMM_config.ck_batched_gemm_dict = ck_batched_gemm_dict.set_index(
["B", "M", "N", "K"]
).to_dict("index")
config = get_CKBatchedGEMM_config.ck_batched_gemm_dict.get((B, M, N, K), None)
if config != None:
mnk = config["kernelName"].split("_")[3].split("x")[1:]
config["tile_m"] = int(mnk[0])
config["tile_n"] = int(mnk[1])
config["tile_k"] = int(mnk[2])
return config
def batched_gemm_a8w8_CK(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
bias: Optional[Tensor] = None,
dtype=dtypes.bf16,
splitK: Optional[int] = None,
):
assert dtype in [
dtypes.bf16,
dtypes.fp16,
], f"Output {dtype=} is currently not supported in batched_gemm_a8w8"
b = XQ.shape[0]
m = XQ.shape[1]
n = WQ.shape[1]
k = XQ.shape[2]
ck_config = get_CKBatchedGEMM_config(b, m, n, k)
if splitK == None:
if ck_config != None:
splitK = ck_config["splitK"]
else:
splitK = 0
Y = torch.empty(b, m, n, dtype=dtype, device=XQ.device)
return batched_gemm_a8w8(XQ, WQ, x_scale, w_scale, Y, bias, splitK)
@compile_ops("module_batched_gemm_a8w8_tune", fc_name="batched_gemm_a8w8_tune")
def batched_gemm_a8w8_tune(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
out: Tensor,
kernelId: int,
splitK=0,
): ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from typing import Optional
import functools
import pandas as pd
from ..jit.core import (
compile_ops,
AITER_CORE_DIR,
)
from ..utility import dtypes
from ..jit.utils.chip_info import get_cu_num
@compile_ops("module_batched_gemm_bf16", fc_name="batched_gemm_bf16")
def batched_gemm_bf16(
XQ: Tensor, WQ: Tensor, out: Tensor, bias: Optional[Tensor] = None, splitK=0
): ...
@functools.lru_cache(maxsize=1024)
def compute_batched_gemm_SplitK(
M: int, N: int, K: int, tile_m: int, tile_n: int, tile_k: int
):
cu_num = get_cu_num()
tile_num = ((M + tile_m - 1) // tile_m) * ((N + tile_n - 1) // tile_n)
cusPerTile = cu_num / tile_num
splitK = 0
while cusPerTile >= pow(2, splitK + 1) and (pow(2, splitK + 1) * tile_k) < 2 * K:
splitK += 1
return splitK
@functools.lru_cache(maxsize=1024)
def get_CKBatchedGEMM_config(
B: int,
M: int,
N: int,
K: int,
):
if not hasattr(get_CKBatchedGEMM_config, "ck_batched_gemm_dict"):
ck_batched_gemm_dict = pd.read_csv(
f"{AITER_CORE_DIR}/aiter/configs/bf16_tuned_batched_gemm.csv"
).drop_duplicates()
get_CKBatchedGEMM_config.ck_batched_gemm_dict = ck_batched_gemm_dict.set_index(
["B", "M", "N", "K"]
).to_dict("index")
config = get_CKBatchedGEMM_config.ck_batched_gemm_dict.get((B, M, N, K), None)
if config != None:
mnk = config["kernelName"].split("_")[2].split("x")[1:]
config["tile_m"] = int(mnk[0])
config["tile_n"] = int(mnk[1])
config["tile_k"] = int(mnk[2])
return config
def batched_gemm_bf16_CK(
XQ: Tensor,
WQ: Tensor,
bias: Optional[Tensor] = None,
dtype=dtypes.bf16,
splitK: Optional[int] = None,
):
assert dtype in [
dtypes.bf16,
dtypes.fp16,
], f"Output {dtype=} is currently not supported in batched_gemm_bf16"
b = XQ.shape[0]
m = XQ.shape[1]
n = WQ.shape[1]
k = XQ.shape[2]
ck_config = get_CKBatchedGEMM_config(b, m, n, k)
if splitK == None:
if ck_config != None:
splitK = ck_config["splitK"]
else:
splitK = 0
Y = torch.empty(b, m, n, dtype=dtype, device=XQ.device)
return batched_gemm_bf16(XQ, WQ, Y, bias, splitK)
@compile_ops("module_batched_gemm_bf16_tune", fc_name="batched_gemm_bf16_tune")
def batched_gemm_bf16_tune(
XQ: Tensor, WQ: Tensor, out: Tensor, kernelId: int, splitK=0
): ...
# SPDX-License-Identifier: MIT
from torch import Tensor
from ..jit.core import compile_ops
MD_NAME = "module_cache"
@compile_ops("module_cache")
def swap_blocks(src: Tensor, dst: Tensor, block_mapping: Tensor) -> None: ...
@compile_ops("module_cache")
def copy_blocks(
key_caches: Tensor, value_caches: Tensor, block_mapping: Tensor
) -> None: ...
@compile_ops("module_cache")
def reshape_and_cache(
key: Tensor,
value: Tensor,
key_cache: Tensor,
value_cache: Tensor,
slot_mapping: Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
asm_layout: bool,
) -> None: ...
@compile_ops("module_cache")
def reshape_and_cache_flash(
key: Tensor,
value: Tensor,
key_cache: Tensor,
value_cache: Tensor,
slot_mapping: Tensor,
kv_cache_dtype: str,
k_scale: Tensor,
v_scale: Tensor,
) -> None: ...
@compile_ops("module_cache")
def reshape_and_cache_with_pertoken_quant(
key: Tensor,
value: Tensor,
key_cache: Tensor,
value_cache: Tensor,
k_dequant_scales: Tensor,
v_dequant_scales: Tensor,
slot_mapping: Tensor,
asm_layout: bool,
) -> None: ...
@compile_ops("module_cache")
def reshape_and_cache_with_block_quant(
key: Tensor,
value: Tensor,
key_cache: Tensor,
value_cache: Tensor,
k_dequant_scales: Tensor,
v_dequant_scales: Tensor,
slot_mapping: Tensor,
asm_layout: bool,
) -> None: ...
@compile_ops("module_cache")
def convert_fp8(
dst_cache: Tensor, src_cache: Tensor, scale: float, kv_cache_dtype: str
) -> None: ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
import torch.distributed as dist
from ..dist.parallel_state import (
ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce,
get_tp_group,
destroy_model_parallel,
destroy_distributed_environment,
)
from ..dist.utils import get_open_port, get_distributed_init_method, get_ip
import aiter
import logging
logger = logging.getLogger("aiter")
def init_dist_env(world_size, rankID):
set_custom_all_reduce(True)
init_distributed_environment(
world_size=world_size,
rank=rankID,
distributed_init_method=get_distributed_init_method(get_ip(), get_open_port()),
)
ensure_model_parallel_initialized(world_size, 1)
# hack custom_allreduce
tp_grp = get_tp_group()
ca_comm = tp_grp.ca_comm
# signal
signal = torch.zeros(world_size * 64, dtype=torch.int64, device=rankID)
ca_comm.signal = signal
ca_comm.register_buffer(signal)
logger.debug(f"RANK: {rankID}/{world_size} init_dist_env...")
def destroy_dist_env():
if dist.is_initialized():
destroy_model_parallel()
destroy_distributed_environment()
torch.cuda.empty_cache()
"""
def all_reduce_asm(inp: torch.Tensor):
tp_grp = get_tp_group()
ca = tp_grp.ca_comm
if ca._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return aiter.all_reduce_asm_(
inp, ca._ptr, ca.signal, ca.buffer, ca._IS_CAPTURING
)
else:
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(inp)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
return aiter.all_reduce_asm_(
inp, ca._ptr, ca.signal, ca.buffer, ca._IS_CAPTURING
)
def all_reduce_rmsnorm(
input: Tensor, residual_in: Tensor, weight: Tensor, bias: Tensor, epsilon: float
):
tp_grp = get_tp_group()
ca = tp_grp.ca_comm
return aiter.all_reduce_rmsnorm_(
input,
residual_in,
weight,
bias,
epsilon,
ca._ptr,
ca.signal,
ca.buffer,
ca._IS_CAPTURING,
)
def all_reduce_rmsnorm_quant(
input: Tensor,
residual_in: Tensor,
xscale: Tensor,
weight: Tensor,
bias: Tensor,
epsilon: float,
):
tp_grp = get_tp_group()
ca = tp_grp.ca_comm
return aiter.all_reduce_rmsnorm_quant_(
input,
residual_in,
xscale,
weight,
bias,
epsilon,
ca._ptr,
ca.signal,
ca.buffer,
ca._IS_CAPTURING,
)
"""
# SPDX-License-Identifier: MIT
from torch import Tensor
from ..jit.core import compile_ops
MD_NAME = "module_custom"
@compile_ops("module_custom")
def wvSpltK(in_a: Tensor, in_b: Tensor, out_c: Tensor, N_in: int, CuCount: int): ...
@compile_ops("module_custom")
def LLMM1(in_a: Tensor, in_b: Tensor, out_c: Tensor, rows_per_block: int): ...
# SPDX-License-Identifier: MIT
from typing import List, Optional, Tuple
import torch
from ..jit.core import compile_ops
MD_NAME = "module_custom_all_reduce"
@compile_ops("module_custom_all_reduce")
def init_custom_ar(
meta: torch.Tensor,
rank_data: torch.Tensor,
handles: List[torch.Tensor],
offsets: List[int],
rank: int,
fully_connected: bool,
) -> int: ...
@compile_ops("module_custom_all_reduce")
def all_reduce(
_fa: int,
inp: torch.Tensor,
out: torch.Tensor,
open_fp8_quant: bool,
reg_buffer: Optional[torch.Tensor] = None,
) -> None: ...
@compile_ops("module_custom_all_reduce")
def all_gather_reg(_fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: ...
@compile_ops("module_custom_all_reduce")
def all_gather_unreg(
_fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
) -> None: ...
@compile_ops("module_custom_all_reduce")
def fused_allreduce_rmsnorm(
_fa: int,
inp: torch.Tensor,
res_inp: torch.Tensor,
res_out: torch.Tensor,
out: torch.Tensor,
w: torch.Tensor,
eps: float,
reg_buffer: Optional[torch.Tensor] = None,
) -> None: ...
def all_reduce_asm_fake_tensor(
inp: torch.Tensor,
ca: int,
reg_sig: torch.Tensor,
reg_buffer: torch.Tensor,
isGraph: bool,
) -> torch.Tensor:
return torch.empty_like(
inp,
dtype=inp.dtype,
device=inp.device,
)
@compile_ops("module_custom_all_reduce", gen_fake=all_reduce_asm_fake_tensor)
def all_reduce_asm_(
inp: torch.Tensor,
ca: int,
reg_sig: torch.Tensor,
reg_buffer: torch.Tensor,
isGraph: bool,
) -> torch.Tensor: ...
def all_reduce_rmsnorm_fake_tensors(
input: torch.Tensor,
residual_in: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
epsilon: float,
ca: int,
reg_sig: torch.Tensor,
reg_buffer: torch.Tensor,
isGraph: bool,
) -> List[torch.Tensor]:
output = torch.empty_like(
input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad
)
residual_out = torch.empty_like(
input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad
)
return [output, residual_out]
@compile_ops("module_custom_all_reduce", gen_fake=all_reduce_rmsnorm_fake_tensors)
def all_reduce_rmsnorm_(
input: torch.Tensor,
residual_in: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
epsilon: float,
ca: int,
reg_sig: torch.Tensor,
reg_buffer: torch.Tensor,
isGraph: bool,
) -> List[torch.Tensor]: ...
# def all_reduce_rmsnorm_quant_fake_tensors(
# input: torch.Tensor,
# residual_in: torch.Tensor,
# weight: torch.Tensor,
# xscale: torch.Tensor,
# bias: torch.Tensor,
# epsilon: float,
# ca: int,
# reg_sig: torch.Tensor,
# reg_buffer: torch.Tensor,
# isGraph: bool,
# ) -> List[torch.Tensor]:
# N = input.size(-1)
# M = input.numel() // N
# output = torch.empty_like(
# input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad
# )
# residual_out = torch.empty_like(
# input, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad
# )
# y_scale = torch.empty((M, 1), dtype=torch.float32, device=input.device)
# return [output, residual_out, y_scale]
# @compile_ops("module_custom_all_reduce", gen_fake=all_reduce_rmsnorm_quant_fake_tensors)
# def all_reduce_rmsnorm_quant_(
# input: torch.Tensor,
# residual_in: torch.Tensor,
# weight: torch.Tensor,
# xscale: torch.Tensor,
# bias: torch.Tensor,
# epsilon: float,
# ca: int,
# reg_sig: torch.Tensor,
# reg_buffer: torch.Tensor,
# isGraph: bool,
# ) -> List[torch.Tensor]: ...
@compile_ops("module_custom_all_reduce")
def dispose(_fa: int) -> None: ...
@compile_ops("module_custom_all_reduce")
def meta_size() -> int: ...
@compile_ops("module_custom_all_reduce")
def register_buffer(
_fa: int, t: torch.Tensor, handles: List[torch.Tensor], offsets: List[int]
) -> None: ...
# def gen_get_graph_buffer_ipc_meta_fake_tensors(_fa: int) -> List[torch.Tensor]:
# handle_sz = 64 # sizeof(hipIpcMemHandle_t) is 64 byte
# num_buffers = 4 # ???
# handles = torch.empty((handle_sz * num_buffers,), dtype=torch.uint8, device="cuda")
# offset_tensor = torch.empty((num_buffers,), dtype=torch.int64, device="cuda")
# return [handles, offset_tensor]
@compile_ops("module_custom_all_reduce")
def get_graph_buffer_ipc_meta(_fa: int) -> Tuple[torch.Tensor, torch.Tensor]: ...
@compile_ops("module_custom_all_reduce")
def register_graph_buffers(
_fa: int, handles: List[torch.Tensor], offsets: List[torch.Tensor]
) -> None: ...
@compile_ops("module_custom_all_reduce")
def allocate_meta_buffer(size: int) -> torch.Tensor: ...
# def get_meta_buffer_ipc_handle_fake(inp: torch.Tensor) -> torch.Tensor:
# handle_size = 64
# if not inp.is_cuda:
# raise RuntimeError("Input tensor must be on CUDA device")
# return torch.empty(handle_size, dtype=torch.uint8, device=inp.device)
@compile_ops("module_custom_all_reduce")
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: ...
\ No newline at end of file
from ..jit.core import compile_ops
# from enum import Enum as Enum
Enum = int
@compile_ops("module_aiter_enum", "ActivationType")
def _ActivationType(dummy: int) -> int: ...
@compile_ops("module_aiter_enum", "QuantType")
def _QuantType(dummy: int) -> int: ...
ActivationType = type(_ActivationType(0))
QuantType = type(_QuantType(0))
# SPDX-License-Identifier: MIT
from torch import Tensor
from ..jit.core import compile_ops
from typing import List, Optional
@compile_ops("module_fused_qk_norm_mrope_cache_quant_shuffle")
def fused_qk_norm_mrope_3d_cache_pts_quant_shuffle(
qkv: Tensor,
qw: Tensor,
kw: Tensor,
cos_sin: Tensor,
positions: Tensor,
num_tokens: int,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_size: int,
is_neox_style: bool,
mrope_section_: List[int],
is_interleaved: bool,
eps: float,
q_out: Tensor,
k_cache: Tensor,
v_cache: Tensor,
slot_mapping: Tensor,
per_tensor_k_scale: Tensor,
per_tensor_v_scale: Tensor,
k_out: Optional[Tensor],
v_out: Optional[Tensor],
return_kv: bool,
use_shuffle_layout: bool,
block_size: int,
x: int,
rotary_dim: int = 0,
) -> None: ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from ..jit.core import compile_ops
from typing import Optional
@compile_ops("module_fused_qk_norm_rope_cache_quant_shuffle")
def fused_qk_norm_rope_cache_quant_shuffle(
qkv: Tensor,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_dim: int,
eps: float,
qw: Tensor,
kw: Tensor,
cos_sin_cache: Tensor,
is_neox_style: bool,
pos_ids: Tensor,
k_cache: Tensor,
v_cache: Tensor,
slot_mapping: Tensor,
kv_cache_dtype: str,
k_scale: Tensor,
v_scale: Tensor,
) -> None: ...
def gen_fused_qk_rmsnorm_fake_tensor(
q: Tensor,
q_weight: Tensor,
q_eps: float,
k: Tensor,
k_weight: Tensor,
k_eps: float,
q_out: Optional[Tensor],
k_out: Optional[Tensor],
) -> tuple[Tensor, Tensor]:
if q_out is None:
q_out = torch.empty_like(q, dtype=q.dtype, device=q.device)
if k_out is None:
k_out = torch.empty_like(k, dtype=k.dtype, device=k.device)
return q_out, k_out
@compile_ops("module_fused_qk_norm_rope_cache_quant_shuffle")
def fused_qk_norm_rope_cache_block_quant_shuffle(
qkv: Tensor,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_dim: int,
eps: float,
qw: Tensor,
kw: Tensor,
cos_sin_cache: Tensor,
is_neox_style: bool,
pos_ids: Tensor,
k_cache: Tensor,
v_cache: Tensor,
slot_mapping: Tensor,
cu_q_len: Tensor,
kv_cache_dtype: str,
k_scale: Tensor,
v_scale: Tensor,
max_tokens_per_batch: int = 0,
) -> None: ...
@compile_ops("module_fused_qk_norm_rope_cache_quant_shuffle")
def fused_qk_norm_rope_cache_pts_quant_shuffle(
qkv: Tensor,
qw: Tensor,
kw: Tensor,
cos_sin: Tensor,
positions: Tensor,
num_tokens: int,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_size: int,
is_neox_style: bool,
eps: float,
q_out: Tensor,
k_cache: Tensor,
v_cache: Tensor,
slot_mapping: Tensor,
per_tensor_k_scale: Tensor,
per_tensor_v_scale: Tensor,
k_out: Optional[Tensor],
v_out: Optional[Tensor],
return_kv: bool,
use_shuffle_layout: bool,
block_size: int,
x: int,
rotary_dim: int = 0,
) -> None: ...
@compile_ops("module_fused_qk_norm_rope_cache_quant_shuffle")
def fused_qk_norm_rope_2way(
q0: Tensor,
k0: Tensor,
q1: Tensor,
k1: Tensor,
w_q0: Tensor,
w_k0: Tensor,
w_q1: Tensor,
w_k1: Tensor,
cos_sin0: Tensor,
cos_sin1: Tensor,
batch_size: int,
num_tokens0: int,
num_tokens1: int,
num_heads_q: int,
num_heads_k: int,
head_size: int,
is_interleaved: bool,
eps: float,
out_q01: Tensor,
out_k01: Tensor,
) -> None: ...
# SPDX-License-Identifier: MIT
import torch
from torch import Tensor
from typing import Optional
import functools
import pandas as pd
from ..jit.core import (
compile_ops,
AITER_CORE_DIR,
)
from ..utility import dtypes
from ..jit.utils.chip_info import get_cu_num
@compile_ops("module_gemm_a8w8", fc_name="gemm_a8w8")
def gemm_a8w8(
XQ: torch.Tensor,
WQ: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
Out: torch.Tensor,
bias: Optional[torch.Tensor] = None,
splitK: int = 0,
) -> torch.Tensor: ...
@compile_ops("module_gemm_a8w8_asm", fc_name="gemm_a8w8_asm")
def gemm_a8w8_asm(
XQ: Tensor, # A:[M, K] i8
WQ: Tensor, # B:[N, K] i8 -> shuffle layout(32,16)
x_scale: Tensor, # A_scale:[M, 1] f32
w_scale: Tensor, # B_scale:[1, N] f32
Out: Tensor, # Out:[M, N] bf16
bias: Tensor, # bias:[1, N] f32
sub_m: Optional[int] = 128,
sub_n: Optional[int] = 128,
pad_a: Optional[int] = 0,
pad_b: Optional[int] = 0,
pad_c: Optional[int] = 0,
splitK: Optional[int] = 0,
) -> torch.Tensor: ...
@compile_ops("module_gemm_a8w8_blockscale", fc_name="gemm_a8w8_blockscale")
def gemm_a8w8_blockscale(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
out: Tensor,
): ...
@compile_ops("module_gemm_a8w8_blockscale_asm", fc_name="flatmm_a8w8_blockscale_asm")
def flatmm_a8w8_blockscale_asm(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
out: Tensor,
): ...
@functools.lru_cache(maxsize=1024)
def compute_gemm_SplitK(M: int, N: int, K: int, tile_m: int, tile_n: int, tile_k: int):
cu_num = get_cu_num()
tile_num = ((M + tile_m - 1) // tile_m) * ((N + tile_n - 1) // tile_n)
cusPerTile = cu_num / tile_num
splitK = 0
while cusPerTile >= pow(2, splitK + 1) and (pow(2, splitK + 1) * tile_k) < 2 * K:
splitK += 1
return splitK
@functools.lru_cache(maxsize=1024)
def get_CKGEMM_config(
M: int,
N: int,
K: int,
):
if not hasattr(get_CKGEMM_config, "ckgemm_dict"):
ckgemm_dict = pd.read_csv(
f"{AITER_CORE_DIR}/aiter/configs/a8w8_tuned_gemm.csv"
).drop_duplicates()
get_CKGEMM_config.ckgemm_dict = ckgemm_dict.set_index(["M", "N", "K"]).to_dict(
"index"
)
config = get_CKGEMM_config.ckgemm_dict.get((M, N, K), None)
if config != None:
mnk = config["kernelName"].split("_")[2].split("x")[1:]
config["tile_m"] = int(mnk[0])
config["tile_n"] = int(mnk[1])
config["tile_k"] = int(mnk[2])
return config
@functools.lru_cache(maxsize=1024)
def get_ASMGEMM_config(M: int, N: int, K: int, bias: bool, dtype: torch.dtype):
if not hasattr(get_ASMGEMM_config, "asmgemm_dict"):
asmGemmDictDf = pd.read_csv(
f"{AITER_CORE_DIR}/aiter/configs/asm_a8w8_gemm.csv"
).drop_duplicates()
asmGemmDictDf.bias = asmGemmDictDf.bias.apply(
lambda s: True if s in ["True", 1, "true"] else False
)
get_ASMGEMM_config.asmgemm_dict = asmGemmDictDf.set_index(
["M", "N", "K", "bias", "outdtype"]
).to_dict("index")
return get_ASMGEMM_config.asmgemm_dict.get((M, N, K, bias, str(dtype)), None)
def gemm_a8w8_ASM(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
bias: Tensor,
dtype=dtypes.bf16,
check=False,
):
"""
Notes for use gemm_a8w8_ASM:
1. WQ(weight) must be shuffle, you can use \
'weightshuffle = shuffle_weight(weight,layout=(32,16))'
2. Use asm gemm must give bias, if not have bias, please give \
'bias=torch.zeros(n,dtype=dtypes.fp32,device='cuda')'
"""
if check:
assert dtype in [
dtypes.bf16,
], f"Output {dtype=} is currently not supported in gemm_a8w8_ASM"
assert (
x_scale.dtype == dtypes.fp32 and w_scale.dtype == dtypes.fp32
), f"{x_scale.dtype=} or {w_scale.dtype=} must be dtypes.fp32"
m = XQ.shape[0]
n = WQ.shape[0]
k = XQ.shape[-1]
if (
x_scale.dtype == dtypes.fp32
and w_scale.dtype == dtypes.fp32
and (asm_config := get_ASMGEMM_config(m, n, k, bias != None, dtype)) != None
):
assert (
bias != None
), "Use asm gemm must give bias, please give a \
bias=torch.zeros(n,dtype=dtypes.fp32,device='cuda')"
splitK = asm_config["splitK"]
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
return gemm_a8w8_asm(XQ, WQ, x_scale, w_scale, Y, bias, splitK=splitK)
return None
def gemm_a8w8_CK(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
bias: Optional[Tensor] = None,
dtype=dtypes.bf16,
splitK: Optional[int] = None,
):
assert dtype in [
dtypes.bf16,
dtypes.fp16,
], f"Output {dtype=} is currently not supported in gemm_a8w8"
m = XQ.shape[0]
n = WQ.shape[0]
k = XQ.shape[-1]
ck_config = get_CKGEMM_config(m, n, k)
if splitK == None:
if ck_config != None:
splitK = ck_config["splitK"]
else:
splitK = 0
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
return gemm_a8w8(XQ, WQ, x_scale, w_scale, Y, bias, splitK)
def gemm_a8w8_blockscale_CK(
XQ: Tensor, WQ: Tensor, x_scale: Tensor, w_scale: Tensor, dtype=dtypes.bf16
):
assert dtype in [
dtypes.bf16,
dtypes.fp16,
], f"Output {dtype=} is currently not supported in gemm_a8w8"
m = XQ.shape[0]
n = WQ.shape[0]
k = XQ.shape[-1]
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
return gemm_a8w8_blockscale(XQ, WQ, x_scale, w_scale, Y)
def flatmm_a8w8_blockscale_ASM(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
dtype=dtypes.fp16,
):
assert dtype in [
dtypes.fp16,
], f"Output {dtype=} is currently not supported in gemm_a8w8"
m = XQ.shape[0]
n = WQ.shape[0]
k = XQ.shape[-1]
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
return flatmm_a8w8_blockscale_asm(XQ, WQ, x_scale, w_scale, Y)
@compile_ops("module_gemm_a8w8_tune", fc_name="gemm_a8w8_tune")
def gemm_a8w8_tune(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
out: Tensor,
kernelId: int,
splitK=0,
): ...
@compile_ops("module_gemm_a8w8_blockscale_tune", fc_name="gemm_a8w8_blockscale_tune")
def gemm_a8w8_blockscale_tune(
XQ: Tensor,
WQ: Tensor,
x_scale: Tensor,
w_scale: Tensor,
out: Tensor,
kernelId: int,
splitK=0,
): ...
# SPDX-License-Identifier: MIT
import torch
from typing import Optional
from ..jit.core import compile_ops
@compile_ops("module_hipbsolgemm")
def hipb_create_extension() -> None: ...
@compile_ops("module_hipbsolgemm")
def hipb_destroy_extension() -> None: ...
def gen_hipb_mm_fake_tensor(
mat1: torch.Tensor,
mat2: torch.Tensor,
solution_index: int,
bias: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
scaleA: Optional[torch.Tensor] = None,
scaleB: Optional[torch.Tensor] = None,
scaleOut: Optional[torch.Tensor] = None,
scaleType: Optional[int] = None,
):
mat1_sizes = mat1.size()
mat2_sizes = mat2.size()
in_dtype = mat1.dtype
out_dtype = out_dtype if out_dtype is not None else in_dtype
result = torch.empty(
(mat1_sizes[0], mat2_sizes[1]), dtype=out_dtype, device=mat1.device
)
return result
@compile_ops("module_hipbsolgemm", gen_fake=gen_hipb_mm_fake_tensor)
def hipb_mm(
mat1: torch.Tensor,
mat2: torch.Tensor,
solution_index: int,
bias: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
scaleA: Optional[torch.Tensor] = None,
scaleB: Optional[torch.Tensor] = None,
scaleOut: Optional[torch.Tensor] = None,
scaleType: Optional[int] = None,
) -> torch.Tensor: ...
@compile_ops("module_hipbsolgemm")
def hipb_findallsols(
mat1: torch.Tensor,
mat2: torch.Tensor,
bias: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
scaleA: Optional[torch.Tensor] = None,
scaleB: Optional[torch.Tensor] = None,
scaleC: Optional[torch.Tensor] = None,
scaleType: Optional[int] = None,
) -> list[int]: ...
@compile_ops("module_hipbsolgemm")
def getHipblasltKernelName() -> None: ...
@compile_ops("module_rocsolgemm")
def rocb_create_extension() -> None: ...
@compile_ops("module_rocsolgemm")
def rocb_destroy_extension() -> None: ...
def gen_rocb_mm_fake_tensor(
arg0: torch.Tensor, arg1: torch.Tensor, arg2: int
) -> torch.Tensor:
mat1_sizes = arg0.size()
mat2_sizes = arg0.size()
in_dtype = arg0.dtype
result = torch.empty(
(mat1_sizes[0], mat2_sizes[1]), dtype=in_dtype, device=arg0.device
)
return result
@compile_ops("module_rocsolgemm", gen_fake=gen_rocb_mm_fake_tensor)
def rocb_mm(arg0: torch.Tensor, arg1: torch.Tensor, arg2: int) -> torch.Tensor: ...
@compile_ops("module_rocsolgemm")
def rocb_findallsols(arg0: torch.Tensor, arg1: torch.Tensor) -> list[int]: ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment