Commit 85762c1a authored by Xiaowei.zhang's avatar Xiaowei.zhang
Browse files

Init the main branch for aiter

parent ae0b3521
Pipeline #3505 canceled with stages
[submodule "3rdparty/composable_kernel"]
path = 3rdparty/composable_kernel
url = ../composable_kernel
branch = rel-5.7.1
[submodule "3rdparty/moe_c"]
path = 3rdparty/moe_c
url = ../Moe
branch = W8A8
Copyright ©
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
graft aiter
graft aiter_meta
\ No newline at end of file
## Installation
method build for develop:
```
git submodule update --init
python setup.py develop
```
method build for whl package:
```
bash das_build.sh
```
If you happen to forget the `--recursive` during `clone`, you can use the following command after `cd aiter`
```
git submodule sync && git submodule update --init --recursive
```
## Run operators supported by aiter
There are number of op test, you can run them with: `python3 op_tests/test_layernorm2d.py`
| **Ops** | **Description** |
|-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|ELEMENT WISE | ops: + - * / |
|SIGMOID | (x) = 1 / (1 + e^-x) |
|AllREDUCE | Reduce + Broadcast |
|KVCACHE | W_K W_V |
|MHA | Multi-Head Attention |
|MLA | Multi-head Latent Attention with [KV-Cache layout](https://docs.flashinfer.ai/tutorials/kv_layout.html#page-table-layout ) |
|PA | Paged Attention |
|FusedMoe | Mixture of Experts |
|QUANT | BF16/FP16 -> FP8/INT4 |
|RMSNORM | root mean square |
|LAYERNORM | x = (x - u) / (σ2 + ϵ) e*0.5 |
|ROPE | Rotary Position Embedding |
|GEMM | D=αAβB+C |
# SPDX-License-Identifier: MIT
import torch
import os
import logging
logger = logging.getLogger("aiter")
def getLogger():
global logger
if not logger.handlers:
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
if int(os.environ.get("AITER_LOG_MORE", 0)):
formatter = logging.Formatter(
fmt="[%(name)s %(levelname)s] %(asctime)s.%(msecs)03d - %(processName)s:%(process)d - %(pathname)s:%(lineno)d - %(funcName)s\n%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
else:
formatter = logging.Formatter(
fmt="[%(name)s] %(message)s",
)
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)
logger.addHandler(console_handler)
if hasattr(torch._dynamo.config, "ignore_logger_methods"):
torch._dynamo.config.ignore_logger_methods = (
logging.Logger.info,
logging.Logger.warning,
logging.Logger.debug,
logger.warning,
logger.info,
logger.debug,
)
return logger
logger = getLogger()
import importlib.util
if importlib.util.find_spec("aiter_") is not None:
from aiter_ import *
from .jit import core
# from .ops.enum import *
from .ops.norm import *
from .ops.quant import *
# from .ops.gemm_op_a8w8 import *
# from .ops.batched_gemm_op_a8w8 import *
# from .ops.batched_gemm_op_bf16 import *
from .ops.aiter_operator import *
from .ops.activation import *
# from .ops.attention import *
# from .ops.custom import *
from .ops.custom_all_reduce import *
from .ops.moe_op import *
from .ops.moe_c_op import *
from .ops.moe_sorting import *
from .ops.pos_encoding import *
# from .ops.cache import *
from .ops.rmsnorm import *
from .ops.awq_gemm_asm import *
from .ops.awq_dq_asm import *
# from .ops.communication import *
from .ops.rope import *
from .ops.topk import *
# from .ops.mha import *
from .ops.gradlib import *
# from .ops.trans_ragged_layout import *
# from . import mla
from .utility import dtypes,fp4_utils
import torch
import torch.nn.functional as F
import ctypes
from typing import Optional
import aiter
from aiter import ActivationType, QuantType, dtypes
from aiter.ops.awq_gemm_asm import *
from aiter.ops.shuffle import reverse_awq_order
from aiter.ops.awq_gemm_asm import awq_gemm_asm
from aiter.ops.awq_dq_asm import awq_dq_asm
def pack_int4_to_int8(low_4bits):
if len(low_4bits) % 2 != 0:
low_4bits = torch.cat([low_4bits, torch.tensor([0], dtype=torch.uint8)])
# 3. 将相邻两个低4位拼成一个 int8 值
# 偶数索引:左移4位作为高4位;奇数索引:低4位
packed = (low_4bits[::2]) | (low_4bits[1::2] << 4)
packed = packed.to(torch.int8) # 转回 int8(有符号)
return packed
def pack_int4_to_int8_64K(low_4bits):
if len(low_4bits) % 2 != 0:
low_4bits = torch.cat([low_4bits, torch.tensor([0], dtype=torch.uint8)])
# 3. 将相邻两个低4位拼成一个 int8 值
# 偶数索引:左移4位作为高4位;奇数索引:低4位
packed = (low_4bits[::128]) | (low_4bits[64::128] << 4)
packed = packed.to(torch.int8) # 转回 int8(有符号)
return packed
# qweight - [K, N // 8]
# qzeros - [K // G, N // 8]
# scales - [K // G, N]
def asm_awq_reorder_and_repack(
qweight: torch.Tensor,
qzeros: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
#WARNING: Only support awq group_size=64
N = qweight.shape[1] * 8
K = qweight.shape[0]
G = K // qzeros.shape[0]
assert K // qzeros.shape[0] == 64, "[ERROR] ASM_AWQ_GEMM not support K Groupsize other than 64!"
# assert (N % 512==0 or N==576), "[ERROR]ASM_AWQ_GEMM Not support Weight N other than 576 or multiplies of 512!"
device = qzeros.device
bits = 4
shifts = torch.arange(0, 32, bits, device=device)
iweights = torch.bitwise_right_shift(
qweight[:, :, None],
shifts[None, None, :],
).to(torch.int8)
iweights = iweights.view(iweights.shape[0], -1)
zeros = torch.bitwise_right_shift(
qzeros[:, :, None],
shifts[None, None, :],
).to(torch.int8)
zeros = zeros.view(qzeros.shape[0], -1)
zeros = reverse_awq_order(zeros)
iweights = reverse_awq_order(iweights)
iweights = torch.bitwise_and(iweights, (2**bits) - 1)
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
iweights_packed = iweights.view(K, -1, 2)
zeros_packed = zeros.view(K//G, -1, 2)
# Repack weight to int32 and pack along the K direction
# [K, N] -> [N, K]
# iweights = iweights.transpose(1, 0).contiguous()
packed_weights = torch.zeros([K, N//2], dtype=torch.int8, device=qweight.device)
packed_zeros = torch.zeros([K//G, N//2], dtype=torch.int8, device=zeros.device)
for i in range(2):
packed_weights |= (iweights_packed[:, :, i].to(torch.int8) << (i * bits))
packed_zeros |= (zeros_packed[:, :, i].to(torch.int8) << (i * bits))
return packed_weights,packed_zeros
def asm_awq_post_dequant_torch(
qweight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
group_size: int,
) -> torch.Tensor:
"""Dequantize weights using PyTorch implementation.
Args:
qweight: Quantized weight tensor
scales: Scale factors tensor
qzeros: Zero points tensor
group_size: Size of groups for quantization
Returns:
Dequantized tensor
"""
if group_size == -1:
group_size = qweight.shape[0]
bits = 4
shifts = torch.arange(0, 8, bits, device=qzeros.device)
#只需要8 bit 展开
iweights = torch.bitwise_right_shift(
qweight[:, :, None],
shifts[None, None, :],
).to(torch.int8)
iweights = iweights.view(iweights.shape[0], -1)
zeros = torch.bitwise_right_shift(
qzeros[:, :, None],
shifts[None, None, :],
).to(torch.int8)
zeros = zeros.view(qzeros.shape[0], -1)
iweights = torch.bitwise_and(iweights, (2**bits) - 1)
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
scales = scales.repeat_interleave(group_size, dim=0)
zeros = zeros.repeat_interleave(group_size, dim=0)
return (iweights - zeros) * scales
def asm_awq_post_dequant(
qweight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
group_size: int,
) -> torch.Tensor:
K = scales.shape[0] * group_size
N = scales.shape[-1]
# device = scales.device
out = torch.empty((K, N), dtype=scales.dtype, device=qweight.device)
awq_dq_asm(out, qweight, qzeros, scales)
return out
# The inference function
# input - [m, k]
# qweight - [n, k // 2]
# qzeros - [k//g, n//2]
# scales - [k//g, n]
def asm_awq_gemm_a16w4(input: torch.tensor,
qweight: torch.tensor,
scales: torch.tensor,
qzeros: torch.tensor) -> torch.tensor:
M,K = input.shape
N = scales.shape[1]
assert K % 256 == 0
device = qzeros.device
out_asm = torch.empty((M, N),
dtype=input.dtype,
device=device)
awq_gemm_asm(out_asm, qweight, input, qzeros, scales)
# out_asm = out_asm.reshape(out_asm.shape[1], -1)
return out_asm
\ No newline at end of file
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"),
0,
repeat(indices, "z -> z d", d=second_dim),
).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(
0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (
(attention_mask + unused_mask) if unused_mask is not None else attention_mask
)
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
```
[
[2, 3, 0, 0, 0, 0],
[3, 2, 0, 0, 0, 0],
[6, 0, 0, 0, 0, 0]
]
```
, which refers to the 3D-attention mask:
```
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
```.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(
seqlen, device=length.device, dtype=length.dtype
).expand(len(length), seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(
attention_mask_in_length.flatten(), as_tuple=False
).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
import os
from pathlib import Path
import functools
import pandas as pd
import torch
import torch.nn.functional as F
from aiter import hipb_create_extension, hipb_mm, getHipblasltKernelName
from aiter import rocb_create_extension, rocb_mm
from aiter import logger, dtypes
from aiter.jit.utils.torch_guard import torch_compile_guard
from typing import Optional
extensions_created = False
@torch_compile_guard()
def scale_mm(
inp: torch.Tensor,
weights: torch.Tensor,
bias: Optional[torch.Tensor] = None,
otype: Optional[torch.dtype] = None,
scale_a: Optional[torch.Tensor] = None,
scale_b: Optional[torch.Tensor] = None,
scale_c: Optional[torch.Tensor] = None,
scale_type: Optional[int] = None,
)-> torch.Tensor:
# scale_type=0, scalar scale
# scale_type=1, channel scale
# scale_type=2, block scale
global extensions_created
if otype is None:
otype = inp.dtype
if extensions_created == False:
hipb_create_extension()
extensions_created = True
if inp.dim() >= 3:
assert(False, "not support 3dim input")
inp_view = inp
return hipb_mm(inp_view, weights.t(), -1, bias, otype, scale_a, scale_b, scale_c, scale_type)
\ No newline at end of file
M,N,K
16, 1536, 7168
16, 3072, 1536
16, 576, 7168
16, 7168, 256
16, 7168, 2048
16, 4608, 7168
16, 7168, 2304
16, 512, 7168
16, 4096, 512
32, 1536, 7168
32, 3072, 1536
32, 576, 7168
32, 7168, 256
32, 7168, 2048
32, 4608, 7168
32, 7168, 2304
32, 512, 7168
32, 4096, 512
64, 1536, 7168
64, 3072, 1536
64, 576, 7168
64, 7168, 256
64, 7168, 2048
64, 4608, 7168
64, 7168, 2304
64, 512, 7168
64, 4096, 512
128, 1536, 7168
128, 3072, 1536
128, 576, 7168
128, 7168, 256
128, 7168, 2048
128, 4608, 7168
128, 7168, 2304
128, 512, 7168
128, 4096, 512
256, 1536, 7168
256, 3072, 1536
256, 576, 7168
256, 7168, 256
256, 7168, 2048
256, 4608, 7168
256, 7168, 2304
256, 512, 7168
256, 4096, 512
512, 1536, 7168
512, 3072, 1536
512, 576, 7168
512, 7168, 256
512, 7168, 2048
512, 4608, 7168
512, 7168, 2304
512, 512, 7168
512, 4096, 512
1024, 1536, 7168
1024, 3072, 1536
1024, 576, 7168
1024, 7168, 256
1024, 7168, 2048
1024, 4608, 7168
1024, 7168, 2304
1024, 512, 7168
1024, 4096, 512
1536, 1536, 7168
1536, 3072, 1536
1536, 576, 7168
1536, 7168, 256
1536, 7168, 2048
1536, 4608, 7168
1536, 7168, 2304
1536, 512, 7168
1536, 4096, 512
2048, 1536, 7168
2048, 3072, 1536
2048, 576, 7168
2048, 7168, 256
2048, 7168, 2048
2048, 4608, 7168
2048, 7168, 2304
2048, 512, 7168
2048, 4096, 512
4096, 1536, 7168
4096, 3072, 1536
4096, 576, 7168
4096, 7168, 256
4096, 7168, 2048
4096, 4608, 7168
4096, 7168, 2304
4096, 512, 7168
4096, 4096, 512
8192, 1536, 7168
8192, 3072, 1536
8192, 576, 7168
8192, 7168, 256
8192, 7168, 2048
8192, 4608, 7168
8192, 7168, 2304
8192, 512, 7168
8192, 4096, 512
16384, 1536, 7168
16384, 3072, 1536
16384, 576, 7168
16384, 7168, 256
16384, 7168, 2048
16384, 4608, 7168
16384, 7168, 2304
16384, 512, 7168
16384, 4096, 512
20480, 1536, 7168
20480, 3072, 1536
20480, 576, 7168
20480, 7168, 256
20480, 7168, 2048
20480, 4608, 7168
20480, 7168, 2304
20480, 512, 7168
20480, 4096, 512
\ No newline at end of file
B,M,N,K
16, 1, 1280, 8192
16, 32, 1280, 8192
16, 64, 1280, 8192
16, 128, 1280, 8192
16, 192, 1280, 8192
16, 256, 1280, 8192
16, 320, 1280, 8192
16, 512, 1280, 8192
16, 1024, 1280, 8192
16, 2048, 1280, 8192
16, 4096, 1280, 8192
16, 8192, 1280, 8192
16, 16384, 1280, 8192
16, 1, 8192, 1024
16, 32, 8192, 1024
16, 64, 8192, 1024
16, 128, 8192, 1024
16, 192, 8192, 1024
16, 256, 8192, 1024
16, 320, 8192, 1024
16, 512, 8192, 1024
16, 1024, 8192, 1024
16, 2048, 8192, 1024
16, 4096, 8192, 1024
16, 8192, 8192, 1024
16, 16384, 8192, 1024
M,N,K
1, 1280, 8192
32, 1280, 8192
64, 1280, 8192
128, 1280, 8192
192, 1280, 8192
256, 1280, 8192
320, 1280, 8192
512, 1280, 8192
1024, 1280, 8192
2048, 1280, 8192
4096, 1280, 8192
8192, 1280, 8192
16384, 1280, 8192
1, 8192, 1024
32, 8192, 1024
64, 8192, 1024
128, 8192, 1024
192, 8192, 1024
256, 8192, 1024
320, 8192, 1024
512, 8192, 1024
1024, 8192, 1024
2048, 8192, 1024
4096, 8192, 1024
8192, 8192, 1024
16384, 8192, 1024
M,N,K,bias,outdtype,splitK,us
128,1280,8192,True,torch.bfloat16,3,13.85
192,1280,8192,True,torch.bfloat16,3,13.90
256,1280,8192,True,torch.bfloat16,3,25.54
320,1280,8192,True,torch.bfloat16,3,25.56
512,1280,8192,True,torch.bfloat16,3,48.06
1024,1280,8192,True,torch.bfloat16,3,94.45
2048,1280,8192,True,torch.bfloat16,3,186.90
4096,1280,8192,True,torch.bfloat16,3,371.85
8192,1280,8192,True,torch.bfloat16,3,742.90
16384,1280,8192,True,torch.bfloat16,3,1483.54
128,8192,1024,True,torch.bfloat16,0,12.67
192,8192,1024,True,torch.bfloat16,0,12.70
256,8192,1024,True,torch.bfloat16,0,23.80
320,8192,1024,True,torch.bfloat16,0,23.82
512,8192,1024,True,torch.bfloat16,0,44.61
1024,8192,1024,True,torch.bfloat16,0,76.33
2048,8192,1024,True,torch.bfloat16,0,140.05
4096,8192,1024,True,torch.bfloat16,0,277.80
8192,8192,1024,True,torch.bfloat16,0,552.69
16384,8192,1024,True,torch.bfloat16,0,1095.12
\ No newline at end of file
{
"tunedCSV": "tuned_awq_gemm_NN.csv",
"kernels": [
{
"solutionId": 0,
"kernel_name": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_2",
"co_file": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 64, "mt1": 32, "numThreads": 512, "wgm": 1 }
},
{
"solutionId": 1,
"kernel_name": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x64x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_2",
"co_file": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x64x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 64, "mt1": 64, "numThreads": 512, "wgm": 1 }
},
{
"solutionId": 2,
"kernel_name": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x128x32_SN_K1_PGR6_SB1_TT2_8_WG16_16_2",
"co_file": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x128x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 64, "mt1": 128, "numThreads": 512, "wgm": 1 }
},
{
"solutionId": 3,
"kernel_name": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT16x32x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_2",
"co_file": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT16x32x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 16, "mt1": 32, "numThreads": 512, "wgm": 1 }
},
{
"solutionId": 4,
"kernel_name": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_3",
"co_file": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT4_2_w4a16_splitK.co",
"Kconfigs": { "mt0": 64, "mt1": 32, "numThreads": 768, "wgm": 1 }
},
{
"solutionId": 5,
"kernel_name": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x64x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_3",
"co_file": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x64x32_SN_K1_PGR6_SB1_TT4_2_w4a16_splitK.co",
"Kconfigs": { "mt0": 64, "mt1": 64, "numThreads": 768, "wgm": 1 }
},
{
"solutionId": 6,
"kernel_name": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x128x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_3",
"co_file": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT64x128x32_SN_K1_PGR6_SB1_TT4_2_w4a16_splitK.co",
"Kconfigs": { "mt0": 64, "mt1": 128, "numThreads": 768, "wgm": 1 }
}
],
"Untunedkernels": [
{
"solutionId": 4,
"kernel_name": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT32x32x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_3",
"co_file": "Cijk_Ailk_Bljk_HHS_BH_UserArgs_MT32x32x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 32, "mt1": 32, "numThreads": 768, "wgm": 1 }
}
]
}
{
"tunedCSV": "tuned_awq_bf16_gemm_NN.csv",
"kernels": [
{
"solutionId": 0,
"kernel_name": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_2",
"co_file": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 64, "mt1": 32, "numThreads": 512, "wgm": 1 }
},
{
"solutionId": 1,
"kernel_name": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x64x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_2",
"co_file": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x64x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 64, "mt1": 64, "numThreads": 512, "wgm": 1 }
},
{
"solutionId": 2,
"kernel_name": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x128x32_SN_K1_PGR6_SB1_TT2_8_WG16_16_2",
"co_file": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x128x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 64, "mt1": 128, "numThreads": 512, "wgm": 1 }
},
{
"solutionId": 3,
"kernel_name": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT16x32x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_2",
"co_file": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT16x32x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 16, "mt1": 32, "numThreads": 512, "wgm": 1 }
},
{
"solutionId": 4,
"kernel_name": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_3",
"co_file": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x32x32_SN_K1_PGR6_SB1_TT4_2_w4a16_splitK.co",
"Kconfigs": { "mt0": 64, "mt1": 32, "numThreads": 768, "wgm": 1 }
},
{
"solutionId": 5,
"kernel_name": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x64x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_3",
"co_file": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x64x32_SN_K1_PGR6_SB1_TT4_2_w4a16_splitK.co",
"Kconfigs": { "mt0": 64, "mt1": 64, "numThreads": 768, "wgm": 1 }
},
{
"solutionId": 6,
"kernel_name": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x128x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_3",
"co_file": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT64x128x32_SN_K1_PGR6_SB1_TT4_2_w4a16_splitK.co",
"Kconfigs": { "mt0": 64, "mt1": 128, "numThreads": 768, "wgm": 1 }
}
],
"Untunedkernels": [
{
"solutionId": 4,
"kernel_name": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT32x32x32_SN_K1_PGR6_SB1_TT2_2_WG16_16_3",
"co_file": "Cijk_Ailk_Bljk_BBS_BH_UserArgs_MT32x32x32_SN_K1_PGR6_SB1_TT4_2_w4a16.co",
"Kconfigs": { "mt0": 32, "mt1": 32, "numThreads": 768, "wgm": 1 }
}
]
}
GPU_ARCHS=gfx936 python3 gradlib/gradlib/gemm_tuner.py \
--tuned_file aiter/configs/asm_tune/tuned_awq_gemm_NN.csv \
--inputSols_file aiter/configs/asm_tune/awq_NN_solutions.json \
--input_file aiter/configs/asm_tune/untuned_awqgemm_NN.csv \
--warmupIters 1 --runIters 3 --fastNoCheck 1 --hsacoOnly 1
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
B,M,N,K
16, 1, 1280, 8192
16, 32, 1280, 8192
16, 64, 1280, 8192
16, 128, 1280, 8192
16, 192, 1280, 8192
16, 256, 1280, 8192
16, 320, 1280, 8192
16, 512, 1280, 8192
16, 1024, 1280, 8192
16, 2048, 1280, 8192
16, 4096, 1280, 8192
16, 8192, 1280, 8192
16, 16384, 1280, 8192
16, 1, 8192, 1024
16, 32, 8192, 1024
16, 64, 8192, 1024
16, 128, 8192, 1024
16, 192, 8192, 1024
16, 256, 8192, 1024
16, 320, 8192, 1024
16, 512, 8192, 1024
16, 1024, 8192, 1024
16, 2048, 8192, 1024
16, 4096, 8192, 1024
16, 8192, 8192, 1024
16, 16384, 8192, 1024
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment