Unverified Commit 946ab56c authored by Xu Kai's avatar Xu Kai Committed by GitHub
Browse files

[feature] add gptq for inference (#4754)

* [gptq] add gptq kernel (#4416)

* add gptq

* refactor code

* fix tests

* replace auto-gptq

* rename inference/quant

* refactor test

* add auto-gptq as an option

* reset requirements

* change assert and check auto-gptq

* add import warnings

* change test flash attn version

* remove example

* change requirements of flash_attn

* modify tests

* [skip ci] change requirements-test

* [gptq] faster gptq cuda kernel (#4494)

* [skip ci] add cuda kernels

* add license

* [skip ci] fix max_input_len

* format files & change test size

* [skip ci]

* [gptq] add gptq tensor parallel (#4538)

* add gptq tensor parallel

* add gptq tp

* delete print

* add test gptq check

* add test auto gptq check

* [gptq] combine gptq and kv cache manager (#4706)

* combine gptq and kv cache manager

* add init bits

* delete useless code

* add model path

* delete useless print and update test

* delete useless import

* move option gptq to shard config

* change replace linear to shardformer

* update bloom policy

* delete useless code

* fix import bug and delete useless code

* change colossalai/gptq to colossalai/quant/gptq

* update import linear for tests

* delete useless code and mv gptq_kernel to kernel directory

* fix triton kernel

* add triton import
parent 1e0e0808
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _tuning_h
#define _tuning_h
// Plain aggregate holding three matmul tuning settings.
// NOTE(review): the meaning of each field is defined by the kernels that
// consume this struct, which are not visible in this header; the per-field
// notes below are assumptions to be confirmed against those kernels.
struct ExLlamaTuning
{
// Integer threshold; presumably selects between matmul strategies — TODO confirm.
int matmul_recons_thd;
// Presumably toggles a fused column-remap path — TODO confirm.
bool matmul_fused_remap;
// Presumably disables half2 arithmetic when true — TODO confirm.
bool matmul_no_half2;
};
#endif
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif
// React to failure on return code != cudaSuccess.
// Evaluates `fn`, stores its result in `_cuda_err`, and jumps to the
// `_cuda_fail` label when the result is not cudaSuccess.
// The macro declares neither name: the enclosing function must provide a
// `_cuda_err` variable and a `_cuda_fail:` label.
#define _cuda_check(fn) \
do { \
{_cuda_err = fn;} \
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)
// React to failure on return code == 0.
// Evaluates `fn`; when it is falsy, sets `_cuda_err` to cudaUnspecified and
// jumps to the `_cuda_fail` label, otherwise sets `_cuda_err` to cudaSuccess.
// As with _cuda_check, the enclosing function must provide `_cuda_err` and
// the `_cuda_fail:` label.
#define _alloc_check(fn) \
do { \
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
else _cuda_err = cudaSuccess; \
} while(false)
#endif
...@@ -6,6 +6,7 @@ try: ...@@ -6,6 +6,7 @@ try:
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
from .rms_norm import rmsnorm_forward from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd from .rotary_embedding_kernel import rotary_embedding_fwd
from .softmax import softmax from .softmax import softmax
...@@ -20,6 +21,7 @@ try: ...@@ -20,6 +21,7 @@ try:
"copy_kv_cache_to_dest", "copy_kv_cache_to_dest",
"rotary_embedding_fwd", "rotary_embedding_fwd",
"token_attention_fwd", "token_attention_fwd",
"gptq_fused_linear_triton",
] ]
except ImportError: except ImportError:
......
# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
import torch
import triton
import triton.language as tl
from auto_gptq.nn_modules.triton_utils import custom_autotune
@triton.jit
def tanh(x):
    """Hyperbolic tangent, expressed through the sigmoid: 2 * sigmoid(2x) - 1."""
    scaled_sigmoid = tl.sigmoid(2 * x)
    return 2 * scaled_sigmoid - 1
@triton.jit
def cosh(x):
    """Hyperbolic cosine computed as (e^x + 1 / e^x) / 2."""
    grow = tl.exp(x)
    decay = 1.0 / grow
    return (grow + decay) * 0.5
# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview
# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation: keep non-negative entries, replace the rest with 0.0.
    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    keep = x >= 0
    return tl.where(keep, x, 0.0)
@triton.jit
def squared_relu(x):
    """
    Squared ReLU (Primer_ paper): x * x for strictly positive x, 0.0 otherwise.
    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    return tl.where(x > 0.0, x * x, 0.0)
@triton.jit
def star_relu(x):
    """
    Star ReLU ("MetaFormer Baselines for Vision"_): 0.8944 * relu(x)^2 - 0.4472.
    .. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf
    """
    gated_square = tl.where(x > 0.0, x * x, 0.0)
    return 0.8944 * gated_square - 0.4472
# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ with a fixed negative slope of 0.01.
    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    leaked = 0.01 * x
    return tl.where(x >= 0.0, x, leaked)
@triton.jit
def gelu(x):
    """
    GeLU_ activation - Gaussian error linear unit (tanh approximation).

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    # sqrt(2 / pi), written as a literal. The original referenced a module
    # constant `_kAlpha` that is never defined in this file, so compiling the
    # GELU path failed on the unresolved name.
    return 0.5 * x * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)))
@triton.jit
def smelu(x):
    """
    SmeLU_ activation - Smooth ReLU with beta=2.0.

    Quadratic blend (x + beta)^2 / (4 * beta) where |x| <= beta, identity
    where x >= beta, and 0.0 elsewhere.
    .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
    """
    beta = 2.0
    outside_band = tl.where(x >= beta, x, 0.0)
    inside_band = (x + beta) * (x + beta) / (4.0 * beta)
    return tl.where(tl.abs(x) <= beta, inside_band, outside_band)
@triton.jit
def silu(x):
    """SiLU activation: the input multiplied by its own sigmoid."""
    gate = tl.sigmoid(x)
    return x * gate
@custom_autotune.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4
        ),
    ],
    key=["M", "N", "K"],
    nearest_power_of_two=True,
    prune_configs_by={
        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
        "perf_model": None,
        "top_k": None,
    },
)
@triton.jit
def cai_gptq_matmul_248_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    scales_ptr,
    zeros_ptr,
    bias_ptr,
    residual_ptr,
    M,
    N,
    K,
    bits,
    maxq,
    gptq_group_size,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_scales,
    stride_zeros,
    QKV_FUSED: tl.constexpr,
    ADD_BIAS: tl.constexpr,
    ADD_RESIDUAL: tl.constexpr,
    ACT_TYPE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    Compute C = act(A x dequant(B) + bias) + residual for a packed quantized B.

    The quantization group of row k is derived as k // gptq_group_size (no
    explicit group-index tensor is read).

    A is read as (M, K) through stride_am / stride_ak.
    B holds 32 // bits quantized values per 32-bit word along K, i.e. it is
      indexed as (K // (32 // bits), N).
    C is written as (M, N) through stride_cm / stride_cn.
    scales is indexed as (group, n) with row stride stride_scales.
    zeros is packed like B but along N: indexed as (group, n // (32 // bits))
      and unpacked with shifts, so it is an integer tensor, not float16.

    The program id is split into a (qkv_offset, tile) pair: programs beyond the
    first num_pid_m * num_pid_n address the next weight / output chunk.
    ACT_TYPE selects the epilogue activation: 1 = relu, 2 = gelu, 3 = silu,
    anything else = none.

    NOTE(review): loads of B, scales and zeros are unmasked, so this presumably
    expects N and K to be multiples of the block sizes — confirm against the
    autotune config pruner.
    """
    infearure_per_bits = 32 // bits
    pid = tl.program_id(axis=0)
    NK = K
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K)

    # Which fused chunk this program works on, then the tile id inside it.
    qkv_offset = pid // (num_pid_m * num_pid_n)
    pid = pid % (num_pid_m * num_pid_n)

    # Grouped ordering of tiles: GROUP_SIZE_M rows of tiles are walked together.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
    a_mask = offs_am[:, None] < M

    # b_ptrs is set up such that it repeats each packed word along the K axis
    # once per value stored in that word.
    b_ptrs = (
        b_ptr
        + qkv_offset * N * NK // infearure_per_bits
        + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
    )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)

    scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :]
    zeros_ptrs = (
        zeros_ptr
        + qkv_offset * NK * N // gptq_group_size // infearure_per_bits
        + (offs_bn[None, :] // infearure_per_bits)
    )

    # shifter extracts the `bits`-wide value of each element from the 32-bit word of B;
    # zeros_shifter does the same for the packed zero points along N.
    shifter = (offs_k % infearure_per_bits) * bits
    zeros_shifter = (offs_bn % infearure_per_bits) * bits
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Group index of every row in the current K block.
    g_idx_base = tl.arange(0, BLOCK_SIZE_K)
    g_idx_base = g_idx_base // gptq_group_size
    g_idx = g_idx_base

    for k in range(0, num_pid_k):
        # Fetch scales and zeros for the groups covered by this K block.
        # (The original also loaded them once before the loop; those values
        # were overwritten here before ever being used.)
        scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        zeros = (zeros >> zeros_shifter[None, :]) & maxq
        zeros = zeros + 1

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

        # Unpack b (packed `bits`-wide values) and dequantize it.
        b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
        b = (b - zeros).to(tl.float16) * scales  # Scale and shift
        accumulator += tl.dot(a, b)

        # Advance along K. A must be advanced in units of its K stride; the
        # original added BLOCK_SIZE_K directly, which is only correct when
        # stride_ak == 1.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
        g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size

    c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
    c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)

    if ADD_BIAS:
        bias_mask = offs_bn < N
        # Each fused chunk reads its own slice of N bias values.
        offs_bn += qkv_offset * N
        bias_ptrs = bias_ptr + stride_cn * offs_bn
        bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0)  # (BLOCK_SIZE_N,)
        accumulator += bias[None, :]

    if ACT_TYPE == 1:
        accumulator = relu(accumulator)
    elif ACT_TYPE == 2:
        accumulator = gelu(accumulator)
    elif ACT_TYPE == 3:
        accumulator = silu(accumulator)

    if ADD_RESIDUAL:
        # NOTE(review): offs_bn has already been shifted by qkv_offset * N when
        # ADD_BIAS is set, so the residual column depends on ADD_BIAS for fused
        # chunks other than the first — confirm this is intended.
        residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        res = tl.load(residual_ptrs, mask=c_mask, other=0.0)
        accumulator += res

    tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8}, num_stages=2, num_warps=4
        ),
    ],
    key=["M", "N", "K"],
    nearest_power_of_two=True,
    prune_configs_by={
        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
        "perf_model": None,
        "top_k": None,
    },
)
@triton.jit
def cai_gptq_idx_matmul_248_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    scales_ptr,
    zeros_ptr,
    idx_ptr,
    bias_ptr,
    residual_ptr,
    M,
    N,
    K,
    bits,
    maxq,
    gptq_group_size,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_scales,
    stride_zeros,
    QKV_FUSED: tl.constexpr,
    ADD_BIAS: tl.constexpr,
    ADD_RESIDUAL: tl.constexpr,
    ACT_TYPE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    Compute C = act(A x dequant(B) + bias) + residual for a packed quantized B,
    reading the quantization group of each K row from idx_ptr.

    A is read as (M, K) through stride_am / stride_ak.
    B holds 32 // bits quantized values per 32-bit word along K, i.e. it is
      indexed as (K // (32 // bits), N).
    C is written as (M, N) through stride_cm / stride_cn.
    scales is indexed as (group, n) with row stride stride_scales.
    zeros is packed like B but along N: indexed as (group, n // (32 // bits))
      and unpacked with shifts, so it is an integer tensor, not float16.
    idx_ptr is read as a length-K vector of group indices.

    The program id is split into a (qkv_offset, tile) pair: programs beyond the
    first num_pid_m * num_pid_n address the next weight / output chunk.
    ACT_TYPE selects the epilogue activation: 1 = relu, 2 = gelu, 3 = silu,
    anything else = none.

    NOTE(review): loads of B, scales, zeros and idx are unmasked, so this
    presumably expects N and K to be multiples of the block sizes — confirm.
    """
    infearure_per_bits = 32 // bits
    pid = tl.program_id(axis=0)
    NK = K
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K)

    # Which fused chunk this program works on, then the tile id inside it.
    qkv_offset = pid // (num_pid_m * num_pid_n)
    pid = pid % (num_pid_m * num_pid_n)

    # Grouped ordering of tiles: GROUP_SIZE_M rows of tiles are walked together.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
    a_mask = offs_am[:, None] < M

    # b_ptrs is set up such that it repeats each packed word along the K axis
    # once per value stored in that word.
    b_ptrs = (
        b_ptr
        + qkv_offset * N * NK // infearure_per_bits
        + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
    )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)

    scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :]
    # NOTE(review): the original first computed zeros_ptrs with a qkv_offset
    # term and then overwrote it with the expression below, so the zero points
    # (unlike the scales) carry no per-chunk offset. The effective behaviour is
    # kept unchanged here — confirm which of the two was intended.
    zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)

    # shifter extracts the `bits`-wide value of each element from the 32-bit word of B;
    # zeros_shifter does the same for the packed zero points along N.
    shifter = (offs_k % infearure_per_bits) * bits
    zeros_shifter = (offs_bn % infearure_per_bits) * bits
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    g_ptrs = idx_ptr + offs_k

    for k in range(0, num_pid_k):
        # Group index, scales and zeros of the rows in this K block.
        # (The original also loaded g_idx and scales once before the loop;
        # those values were overwritten here before ever being used.)
        g_idx = tl.load(g_ptrs)
        scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        zeros = (zeros >> zeros_shifter[None, :]) & maxq
        zeros = zeros + 1

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

        # Unpack b (packed `bits`-wide values) and dequantize it.
        b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values
        b = (b - zeros).to(tl.float16) * scales  # Scale and shift
        accumulator += tl.dot(a, b)

        # Advance along K. A must be advanced in units of its K stride; the
        # original added BLOCK_SIZE_K directly, which is only correct when
        # stride_ak == 1.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
        g_ptrs += BLOCK_SIZE_K

    c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
    c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)

    if ADD_BIAS:
        bias_mask = offs_bn < N
        # Each fused chunk reads its own slice of N bias values.
        offs_bn += qkv_offset * N
        bias_ptrs = bias_ptr + stride_cn * offs_bn
        bias = tl.load(bias_ptrs, mask=bias_mask, other=0.0)  # (BLOCK_SIZE_N,)
        accumulator += bias[None, :]

    if ACT_TYPE == 1:
        accumulator = relu(accumulator)
    elif ACT_TYPE == 2:
        accumulator = gelu(accumulator)
    elif ACT_TYPE == 3:
        accumulator = silu(accumulator)

    if ADD_RESIDUAL:
        # NOTE(review): offs_bn has already been shifted by qkv_offset * N when
        # ADD_BIAS is set, so the residual column depends on ADD_BIAS for fused
        # chunks other than the first — confirm this is intended.
        residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
        res = tl.load(residual_ptrs, mask=c_mask, other=0.0)
        accumulator += res

    tl.store(c_ptrs, accumulator, mask=c_mask)
def gptq_fused_linear_triton(
    input,
    qweight,
    scales,
    qzeros,
    bias,
    residual,
    bits,
    maxq,
    gptq_group_size,
    qkv_fused,
    add_bias,
    add_residual,
    g_idx=None,
    act_type=0,
):
    """Launch the fused GPTQ linear Triton kernel and return its output.

    Args:
        input: CUDA tensor; dimensions 0 and 1 are passed to the kernel as M and K.
        qweight: CUDA tensor of packed weights; dimension 1 is passed as N.
        scales: CUDA tensor of scales; its dimension-0 stride is passed to the kernel.
        qzeros: CUDA tensor of packed zero points; its dimension-0 stride is passed.
        bias: forwarded to the kernel (read only when ``add_bias`` is set).
        residual: forwarded to the kernel (read only when ``add_residual`` is set).
        bits, maxq, gptq_group_size: forwarded to the kernel unchanged.
        qkv_fused: when truthy, three times as many programs are launched and
            the output holds three stacked results.
        add_bias, add_residual: forwarded as the kernel's ADD_BIAS / ADD_RESIDUAL.
        g_idx: optional group-index tensor; when given, the kernel variant that
            reads group indices from memory is used.
        act_type: forwarded as the kernel's ACT_TYPE.

    Returns:
        A float16 tensor on ``input.device``: shape
        ``(3, input.shape[0], qweight.shape[1])`` when ``qkv_fused`` is truthy,
        otherwise ``(input.shape[0], qweight.shape[1])``.
    """
    for tensor, label in ((input, "input"), (qweight, "qweight"), (scales, "scales"), (qzeros, "qzeros")):
        assert tensor.is_cuda, f"{label} is not in cuda"

    rows = input.shape[0]
    cols = qweight.shape[1]
    # Fused QKV runs the same tiling three times over, once per projection.
    chunks = 3 if qkv_fused else 1

    def grid(META):
        tiles = triton.cdiv(rows, META["BLOCK_SIZE_M"]) * triton.cdiv(cols, META["BLOCK_SIZE_N"])
        return (tiles * chunks,)

    with torch.cuda.device(input.device):
        output = torch.empty((rows * chunks, cols), device=input.device, dtype=torch.float16)

        # The two kernels differ only in the extra group-index pointer.
        if g_idx is None:
            kernel = cai_gptq_matmul_248_kernel
            pointer_args = (input, qweight, output, scales, qzeros, bias, residual)
        else:
            kernel = cai_gptq_idx_matmul_248_kernel
            pointer_args = (input, qweight, output, scales, qzeros, g_idx, bias, residual)

        kernel[grid](
            *pointer_args,
            rows,
            cols,
            input.shape[1],
            bits,
            maxq,
            gptq_group_size,
            input.stride(0),
            input.stride(1),
            qweight.stride(0),
            qweight.stride(1),
            output.stride(0),
            output.stride(1),
            scales.stride(0),
            qzeros.stride(0),
            QKV_FUSED=qkv_fused,
            ADD_BIAS=add_bias,
            ADD_RESIDUAL=add_residual,
            ACT_TYPE=act_type,
        )

        if qkv_fused:
            return output.view(3, rows, cols)
        return output
...@@ -32,10 +32,13 @@ class ShardConfig: ...@@ -32,10 +32,13 @@ class ShardConfig:
enable_fused_normalization: bool = False enable_fused_normalization: bool = False
enable_flash_attention: bool = False enable_flash_attention: bool = False
enable_jit_fused: bool = False enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
enable_all_optimization: bool = False enable_all_optimization: bool = False
inference_only: bool = False inference_only: bool = False
inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@property @property
......
import argparse
import logging
import os
import time
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def print_perf_stats(latency_set, config, bs, warmup=3):
    """Print average latency, bandwidth, flops and throughput figures.

    Args:
        latency_set: iterable of per-token latencies in seconds.
        config: object exposing ``hidden_size`` and either ``num_layers`` or
            ``num_hidden_layers``.
        bs: batch size used to scale the flops and throughput figures.
        warmup: number of leading measurements to discard.

    Nothing is printed when no measurement is left after dropping the warmup
    entries.
    """
    # trim warmup queries
    latencies = list(latency_set)[warmup:]
    count = len(latencies)
    if count == 0:
        return

    latencies.sort()
    avg = sum(latencies) / count

    # Look up the fallback attribute only when it is needed. The original
    # `getattr(config, "num_layers", config.num_hidden_layers)` evaluated the
    # default eagerly and raised AttributeError for configs that define
    # `num_layers` but not `num_hidden_layers`.
    if hasattr(config, "num_layers"):
        num_layers = config.num_layers
    else:
        num_layers = config.num_hidden_layers

    num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
    num_bytes = 2  # float16

    print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
    print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
    print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
    print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
def bench_bloom(args):
    """Benchmark generation latency of a GPTQ-quantized Bloom model.

    Reads ``path``, ``quantized_path``, ``batch_size``, ``input_len``,
    ``output_len`` and ``tp_size`` from ``args``, runs ten timed ``generate``
    calls on random token ids and prints per-iteration timings followed by the
    aggregated statistics.
    """
    pretrained_model_dir = args.path
    quantized_model_dir = args.quantized_path
    max_batch_size = args.batch_size
    max_input_len = args.input_len
    max_output_len = args.output_len

    # The tokenizer is loaded from the pretrained directory but is not used for
    # generation below, which runs on random token ids.
    tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir)
    tokenizer.pad_token = tokenizer.eos_token

    # load quantized model to the first GPU
    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
                                               device=torch.cuda.current_device(),
                                               inject_fused_attention=False)
    model = model.half()
    model_config = model.config

    # init TPInferEngine and shard the original model
    # To benchmark torch original, comment out the line of optimizing model
    # The original built a first engine (without inference_gptq), a first
    # generate_kwargs and a first set of inputs that were all overwritten
    # before use; only the GPTQ-enabled setup is kept.
    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
                               inference_only=True,
                               inference_gptq=True)
    infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

    # prepare data for generation
    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
    input_tokens = {
        "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
        "attention_mask": torch.ones((max_batch_size, max_input_len))
    }
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())

    iters = 10
    times = []
    for i in range(iters):
        # Synchronize around the call so the wall-clock interval covers the
        # GPU work of this iteration only.
        torch.cuda.synchronize()
        start = time.time()
        outputs = infer_engine.generate(input_tokens, **generate_kwargs)
        torch.cuda.synchronize()
        end = time.time()
        out_len = outputs.shape[1]
        print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
        # Latency per newly generated token.
        times.append((end - start) / (out_len - max_input_len))

    print_perf_stats(times, model_config, max_batch_size)
def check_bloom(rank, world_size, port, args):
    """Per-process entry point: initialise the distributed context, then run the Bloom benchmark."""
    disable_existing_loggers()
    launch_options = {
        "config": {},
        "rank": rank,
        "world_size": world_size,
        "host": 'localhost',
        "port": port,
        "backend": 'nccl',
    }
    colossalai.launch(**launch_options)
    bench_bloom(args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom(args):
    """Spawn one ``check_bloom`` worker per tensor-parallel rank."""
    num_workers = args.tp_size
    spawn(check_bloom, num_workers, args=args)
# Command-line entry point for the Bloom GPTQ benchmark.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
    # The help text previously repeated 'Model path' for both path options.
    parser.add_argument('-q', '--quantized_path', type=str, help='Quantized model path', required=True)
    parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
    parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
    parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
    args = parser.parse_args()
    test_bloom(args)
import argparse
import logging
import os
import time
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
from torch import distributed as dist
from torch.profiler import ProfilerActivity, profile, record_function
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline
import colossalai
from colossalai.gptq import CaiQuantLinear
from colossalai.gptq.gptq_tp import replace_autogptq_linear
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def init_to_get_rotary(self, base=10000):
    """Precompute rotary-embedding cos/sin tables and attach them to ``self``.

    Sets ``self.config.head_dim_`` and stores float16 CUDA tensors in
    ``self._cos_cached`` and ``self._sin_cached``.

    Args:
        self: object with a ``config`` exposing ``hidden_size`` and
            ``num_attention_heads`` and, optionally, ``rope_scaling``,
            ``max_sequence_length`` and ``max_position_embeddings``.
        base: base of the rotary frequencies.
    """
    config = self.config
    config.head_dim_ = config.hidden_size // config.num_attention_heads

    # `rope_scaling` may be missing, None, a mapping (e.g. {"type": ...,
    # "factor": ...}) or an object with a `.factor` attribute. The original
    # only handled the attribute form and raised AttributeError on a dict.
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is None:
        rope_scaling_factor = 1.0
    elif isinstance(rope_scaling, dict):
        rope_scaling_factor = rope_scaling["factor"]
    else:
        rope_scaling_factor = rope_scaling.factor

    if hasattr(config, "max_sequence_length"):
        max_seq_len = config.max_sequence_length
    elif hasattr(config, "max_position_embeddings"):
        max_seq_len = config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor

    base = float(base)
    inv_freq = 1.0 / (base**(torch.arange(0, config.head_dim_, 2, device="cpu", dtype=torch.float32) /
                             config.head_dim_))
    # The table is built with 1024 * 64 positions beyond max_seq_len.
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
    return
def print_perf_stats(latency_set, config, bs, warmup=3):
    """Print average latency, bandwidth, flops and throughput figures.

    Args:
        latency_set: iterable of per-token latencies in seconds.
        config: object exposing ``hidden_size`` and either ``num_layers`` or
            ``num_hidden_layers``.
        bs: batch size used to scale the flops and throughput figures.
        warmup: number of leading measurements to discard.

    Nothing is printed when no measurement is left after dropping the warmup
    entries.
    """
    # trim warmup queries
    latencies = list(latency_set)[warmup:]
    count = len(latencies)
    if count == 0:
        return

    latencies.sort()
    avg = sum(latencies) / count

    # Look up the fallback attribute only when it is needed. The original
    # `getattr(config, "num_layers", config.num_hidden_layers)` evaluated the
    # default eagerly and raised AttributeError for configs that define
    # `num_layers` but not `num_hidden_layers`.
    if hasattr(config, "num_layers"):
        num_layers = config.num_layers
    else:
        num_layers = config.num_hidden_layers

    num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
    num_bytes = 2

    print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
    print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
    print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
    print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
def run_llama_test(args):
    """Benchmark generation latency of a GPTQ-quantized Llama model.

    Reads ``path``, ``quantized_path``, ``batch_size``, ``input_len``,
    ``output_len`` and ``tp_size`` from ``args``, runs ten timed ``generate``
    calls on random token ids and prints per-iteration timings followed by the
    aggregated statistics.
    """
    pretrained_model_dir, quantized_model_dir = args.path, args.quantized_path
    max_batch_size, max_input_len, max_output_len = args.batch_size, args.input_len, args.output_len

    tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # load quantized model to the first GPU
    model = AutoGPTQForCausalLM.from_quantized(
        quantized_model_dir,
        device=torch.cuda.current_device(),
        inject_fused_attention=False,
    )
    init_to_get_rotary(model.model.model, base=10000)
    model_config = model.config

    use_tensor_parallel = True if args.tp_size > 1 else False
    shard_config = ShardConfig(
        enable_tensor_parallelism=use_tensor_parallel,
        inference_only=True,
        inference_gptq=True,
    )
    infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

    generate_kwargs = {"max_new_tokens": max_output_len, "do_sample": False}
    batch_shape = (max_batch_size, max_input_len)
    input_tokens = {
        "input_ids": torch.randint(1, 1000, batch_shape, device='cuda'),
        "attention_mask": torch.ones(batch_shape, device='cuda'),
    }

    per_token_latencies = []
    for i in range(10):
        # Synchronize on both sides so the interval covers this call's GPU work.
        torch.cuda.synchronize()
        start = time.time()
        outputs = infer_engine.generate(input_tokens, **generate_kwargs)
        torch.cuda.synchronize()
        end = time.time()

        out_len = outputs.shape[1]
        print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
        # Latency per newly generated token.
        per_token_latencies.append((end - start) / (out_len - max_input_len))

    print_perf_stats(per_token_latencies, model_config, max_batch_size)
def check_llama(rank, world_size, port, args):
    """Per-process entry point: initialise the distributed context, then run the Llama benchmark."""
    disable_existing_loggers()
    launch_options = {
        "config": {},
        "rank": rank,
        "world_size": world_size,
        "host": 'localhost',
        "port": port,
        "backend": 'nccl',
    }
    colossalai.launch(**launch_options)
    run_llama_test(args)
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama(args):
    """Spawn one ``check_llama`` worker per tensor-parallel rank."""
    num_workers = args.tp_size
    spawn(check_llama, num_workers, args=args)
# Command-line entry point for the Llama GPTQ benchmark.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
    # The help text previously repeated 'Model path' for both path options.
    parser.add_argument('-q', '--quantized_path', type=str, help='Quantized model path', required=True)
    parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
    parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
    parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
    args = parser.parse_args()
    test_llama(args)
import os
import torch
import re
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag
class GPTQBuilder(Builder):
    """Op builder describing how to compile or load the ``cu_gptq`` CUDA extension."""

    NAME = "cu_gptq"
    PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq"

    def __init__(self):
        super().__init__(name=GPTQBuilder.NAME,
                         prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        """Return the include directories: the ``gptq`` source dir and the CUDA headers."""
        ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()]
        return ret

    def sources_files(self):
        """Return the absolute paths of the extension's C++/CUDA sources."""
        ret = [
            self.csrc_abs_path(fname) for fname in [
                'gptq/linear_gptq.cpp',
                'gptq/column_remap.cu',
                'gptq/cuda_buffers.cu',
                'gptq/q4_matmul.cu',
                'gptq/q4_matrix.cu'
            ]
        ]
        return ret

    def cxx_flags(self):
        """Return the host-compiler flags."""
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self):
        """Return the nvcc flags, adding a ``-gencode`` pair for every sm_80+ arch torch reports."""
        # The original passed both '-std=c++14' and '-std=c++17'. Only one
        # language standard can apply, so the conflicting c++14 flag is
        # dropped and the later c++17 setting is kept.
        extra_cuda_flags = ['-v',
                            '-std=c++17', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
                            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK', "-lcublas"
                            ]

        for arch in torch.cuda.get_arch_list():
            res = re.search(r'sm_(\d+)', arch)
            if res:
                arch_cap = res[1]
                # Only compute capability 8.0 and newer gets a -gencode entry.
                if int(arch_cap) >= 80:
                    extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])

        ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)
\ No newline at end of file
...@@ -18,3 +18,4 @@ SentencePiece ...@@ -18,3 +18,4 @@ SentencePiece
ninja ninja
flash_attn==2.0.5 flash_attn==2.0.5
datasets datasets
# auto-gptq does not currently support torch 1.12
import math
import time
import numpy as np
import pytest
import torch
import torch.nn as nn
import transformers
from packaging import version
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
# Optional dependencies needed by the GPTQ comparison test. The flag lets the
# test be skipped instead of failing at collection time when they are missing.
try:
    from auto_gptq.modeling._utils import autogptq_post_init
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
    from exllama_kernels import prepare_buffers, set_tuning_params
    from colossalai.inference.quant.gptq import CaiQuantLinear
    HAS_AUTO_GPTQ = True
# Catch only import failures (as the other guards in this file do); the
# original bare `except:` also swallowed KeyboardInterrupt and SystemExit.
except ImportError:
    HAS_AUTO_GPTQ = False
    print("please install AutoGPTQ from https://github.com/PanQiWei/AutoGPTQ")
import warnings
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
HAS_GPTQ_CUDA = False
# torch.version.cuda is None on builds without CUDA; version.parse(None) would
# raise at import time, so treat that case as "not supported".
TRITON_CUDA_SUPPORT = (torch.version.cuda is not None
                       and version.parse(torch.version.cuda) > version.parse('11.4'))
max_inner_outer_dim = 1
max_input_len = 1
max_dq_buffer_size = 1
gptq_temp_dq_buffer = None
gptq_temp_state_buffer = None
def init_buffer(cai_linear, use_act_order=False):
    """Grow the module-level buffer sizes for ``cai_linear`` and hand fresh buffers to ``gptq_cuda``.

    Updates the module globals ``max_dq_buffer_size``, ``max_inner_outer_dim``,
    ``max_input_len``, ``gptq_temp_state_buffer`` and ``gptq_temp_dq_buffer``,
    then registers the two buffers and the tuning parameters with ``gptq_cuda``.
    """
    global max_dq_buffer_size, max_input_len, max_inner_outer_dim
    global gptq_temp_dq_buffer, gptq_temp_state_buffer

    max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8)

    if use_act_order:
        max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures)
        max_input_len = 4096

    current = torch.cuda.current_device()

    # The temp_state buffer is required to reorder X in the act-order case.
    # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
    gptq_temp_state_buffer = torch.zeros(
        (max_input_len, max_inner_outer_dim),
        dtype=torch.float16,
        device=current,
    )
    gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=current)
    gptq_cuda.prepare_buffers(torch.device(current), gptq_temp_state_buffer, gptq_temp_dq_buffer)

    # Using the default from exllama repo here.
    recons_thd, fused_remap, no_half2 = 8, False, False
    gptq_cuda.set_tuning_params(recons_thd, fused_remap, no_half2)
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ,
                    reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq")
def test_gptq_linear():
    """Check that CaiQuantLinear matches the AutoGPTQ quantized linear layer.

    Both layers receive the same random packed weights and the same scale
    offset; their outputs are compared for a single-token input and for a
    16-token input with loose tolerances (rtol = atol = 0.1).
    """
    infeature = 1024
    outfeature = 1024
    group_size = 128
    wbits = 4
    inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device())
    batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device())
    device = torch.device("cuda:0")

    # Reference implementation from AutoGPTQ.
    linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits)
    linear = linear_class(
        bits=4,
        group_size=group_size,
        infeatures=infeature,
        outfeatures=outfeature,
        bias=False,
    )

    # Seed before drawing the random packed weights so the test is repeatable.
    torch.manual_seed(42)
    linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
    linear.scales = linear.scales + 0.002
    linear = linear.to(device)

    # Layer under test, given the same packed weights and the same scale offset.
    # NOTE(review): the fifth positional argument is True here while the
    # reference layer is built with bias=False — presumably a bias flag; confirm.
    cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True)
    cai_linear.qweight.data.copy_(linear.qweight)
    cai_linear.scales = cai_linear.scales + 0.002
    cai_linear = cai_linear.to(device)

    linear = autogptq_post_init(linear, use_act_order=False)

    # These locals shadow the module-level globals of the same names; they size
    # the scratch buffers handed to the exllama kernels below.
    max_inner_outer_dim = max(infeature, outfeature)
    max_dq_buffer_size = linear.infeatures * linear.outfeatures
    max_input_len = 2048
    buffers = {
        "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device),
        "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
    }
    prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])

    # Using the default from exllama repo here.
    matmul_recons_thd = 8
    matmul_fused_remap = False
    matmul_no_half2 = False
    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

    with torch.no_grad():
        gptq_out = linear(inps)
        batch_gptq_out = linear(batch_inps)
        torch.cuda.synchronize()
        cai_out = cai_linear(inps)
        torch.cuda.synchronize()
        batch_cai_out = cai_linear(batch_inps)
        torch.cuda.synchronize()

    assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01)
    assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01)
# Allow running the comparison test directly as a script, without pytest.
if __name__ == "__main__":
    test_gptq_linear()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment