Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoGPTQ
Commits
da900c3b
Commit
da900c3b
authored
Sep 19, 2024
by
yangql
Browse files
Initial commit
parents
Changes
195
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3133 additions
and
0 deletions
+3133
-0
auto_gptq/utils/perplexity_utils.py
auto_gptq/utils/perplexity_utils.py
+223
-0
autogptq_extension/exllamav2/config.h
autogptq_extension/exllamav2/config.h
+13
-0
autogptq_extension/exllamav2/cpp/util.h
autogptq_extension/exllamav2/cpp/util.h
+12
-0
autogptq_extension/exllamav2/ext_hip.cpp
autogptq_extension/exllamav2/ext_hip.cpp
+136
-0
autogptq_extension/exllamav2/hip/compat.cuh
autogptq_extension/exllamav2/hip/compat.cuh
+58
-0
autogptq_extension/exllamav2/hip/compat_gemm.cuh
autogptq_extension/exllamav2/hip/compat_gemm.cuh
+40
-0
autogptq_extension/exllamav2/hip/matrix_view.cuh
autogptq_extension/exllamav2/hip/matrix_view.cuh
+124
-0
autogptq_extension/exllamav2/hip/q_gemm.cuh
autogptq_extension/exllamav2/hip/q_gemm.cuh
+36
-0
autogptq_extension/exllamav2/hip/q_gemm.hip
autogptq_extension/exllamav2/hip/q_gemm.hip
+214
-0
autogptq_extension/exllamav2/hip/q_gemm_kernel.cuh
autogptq_extension/exllamav2/hip/q_gemm_kernel.cuh
+489
-0
autogptq_extension/exllamav2/hip/q_gemm_kernel_gptq.cuh
autogptq_extension/exllamav2/hip/q_gemm_kernel_gptq.cuh
+226
-0
autogptq_extension/exllamav2/hip/q_matrix.cuh
autogptq_extension/exllamav2/hip/q_matrix.cuh
+75
-0
autogptq_extension/exllamav2/hip/q_matrix.hip
autogptq_extension/exllamav2/hip/q_matrix.hip
+630
-0
autogptq_extension/exllamav2/hip/quant/qdq_2.cuh
autogptq_extension/exllamav2/hip/quant/qdq_2.cuh
+106
-0
autogptq_extension/exllamav2/hip/quant/qdq_3.cuh
autogptq_extension/exllamav2/hip/quant/qdq_3.cuh
+171
-0
autogptq_extension/exllamav2/hip/quant/qdq_4.cuh
autogptq_extension/exllamav2/hip/quant/qdq_4.cuh
+230
-0
autogptq_extension/exllamav2/hip/quant/qdq_5.cuh
autogptq_extension/exllamav2/hip/quant/qdq_5.cuh
+210
-0
autogptq_extension/exllamav2/hip/quant/qdq_6.cuh
autogptq_extension/exllamav2/hip/quant/qdq_6.cuh
+46
-0
autogptq_extension/exllamav2/hip/quant/qdq_8.cuh
autogptq_extension/exllamav2/hip/quant/qdq_8.cuh
+41
-0
autogptq_extension/exllamav2/hip/quant/qdq_util.cuh
autogptq_extension/exllamav2/hip/quant/qdq_util.cuh
+53
-0
No files found.
auto_gptq/utils/perplexity_utils.py
0 → 100644
View file @
da900c3b
import
sys
import
numpy
as
np
import
torch
from
datasets
import
load_dataset
from
tqdm
import
tqdm
class
Perplexity
:
"""
A class for calculating the perplexity of a language model.
"""
def
__init__
(
self
,
model
,
tokenizer
,
dataset_path
=
"wikitext"
,
dataset_name
=
None
,
split
=
"test"
,
text_column
=
"text"
,
):
"""
Calculate perplexity using the same method as seen in llama.cpp.
Parameters
----------
model : AutoModelForCausalLM
The language model for which the perplexity is calculated.
tokenizer : AutoTokenizer
The tokenizer corresponding to the model.
device : str, optional
The device to run the calculations on. If auto, the device that your model uses
will be the device used for these calculations. Default is 'auto'.
dataset_path : str, optional
The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'.
dataset_name : str, optional
The name of the dataset. Default is None.
split : str, optional
The split of the dataset to use. Default is 'test'.
text_column : str, optional
The name of the column in the dataset that contains the text data. Default is 'text'.
"""
self
.
_model
=
model
self
.
_tokenizer
=
tokenizer
self
.
_dataset_path
=
dataset_path
self
.
_dataset_name
=
dataset_name
self
.
_split
=
split
self
.
_text_column
=
text_column
self
.
_text
=
self
.
_prepare_data
()
def
_get_device
(
self
):
if
torch
.
backends
.
mps
.
is_available
():
return
"mps"
elif
torch
.
cuda
.
is_available
():
return
"cuda:0"
else
:
return
"cpu"
def
_prepare_data
(
self
):
"""
Prepares the dataset by loading and formatting.
Returns
-------
str
The formatted dataset as a single string.
"""
if
self
.
_dataset_path
==
"wikitext"
:
self
.
_dataset_name
=
"wikitext-2-raw-v1"
# Load the dataset
data
=
load_dataset
(
self
.
_dataset_path
,
self
.
_dataset_name
,
split
=
self
.
_split
)
# Format the text column of the dataset
text_list
=
[
"
\n
"
if
s
==
""
else
s
for
s
in
data
[
self
.
_text_column
]]
return
""
.
join
(
text_list
)
@
staticmethod
def
softmax
(
logits
):
"""
Static method for applying the softmax function.
Parameters
----------
logits : np.ndarray
The input to the softmax function.
Returns
-------
np.ndarray
The output of the softmax function.
"""
e_x
=
np
.
exp
(
logits
-
np
.
max
(
logits
))
return
e_x
/
e_x
.
sum
(
axis
=
0
)
def
calculate_perplexity
(
self
,
n_ctx
=
512
,
n_batch
=
512
):
"""
Calculates the perplexity of the language model.
Parameters
----------
n_ctx : int
The context size.
n_batch : int
The batch size.
Returns
-------
list
The list of perplexity scores calculated.
"""
# Tokenize the text
self
.
_tokenizer
.
model_max_length
=
sys
.
maxsize
tokens
=
self
.
_tokenizer
(
self
.
_text
,
truncation
=
False
,
return_tensors
=
"pt"
).
input_ids
.
to
(
self
.
_model
.
device
)
nll
=
0.0
# Negative log likelihood
count
=
0
# Counter for processed tokens
curr_ppl
=
0
all_perplexity
=
[]
with
tqdm
(
range
(
len
(
tokens
[
0
])
//
n_ctx
),
desc
=
"Perplexity: - "
)
as
progress
:
for
i
in
progress
:
# Process each batch of tokens
nll
,
count
=
self
.
_process_batch
(
i
,
n_ctx
,
n_batch
,
tokens
,
nll
,
count
)
# Calculate and display the current perplexity
curr_ppl
=
np
.
exp
(
nll
/
count
)
all_perplexity
.
append
(
curr_ppl
)
progress
.
set_description
(
f
"Perplexity:
{
curr_ppl
:.
4
f
}
"
)
return
all_perplexity
def
_process_batch
(
self
,
i
,
n_ctx
,
n_batch
,
tokens
,
nll
,
count
):
"""
Processes each batch of tokens.
Parameters
----------
i : int
The batch index.
n_ctx : int
The context size.
n_batch : int
The batch size.
tokens : torch.Tensor
The tokenized text.
nll : float
The current negative log likelihood.
count : int
The current count of processed tokens.
Returns
-------
float
The updated negative log likelihood.
int
The updated count of processed tokens.
"""
start
=
i
*
n_ctx
end
=
start
+
n_ctx
num_batches
=
(
n_ctx
+
n_batch
-
1
)
//
n_batch
logits
=
[]
for
j
in
range
(
num_batches
):
batch_start
=
start
+
j
*
n_batch
batch_size
=
min
(
end
-
batch_start
,
n_batch
)
token_org
=
tokens
[
0
][
batch_start
].
item
()
if
j
==
0
:
# Replace the first token with the BOS token
tokens
[
0
][
batch_start
]
=
self
.
_tokenizer
.
bos_token_id
# Compute the logits for the current batch of tokens
batch_logits
=
self
.
_compute_batch_logits
(
tokens
,
batch_start
,
batch_size
)
tokens
[
0
][
batch_start
]
=
token_org
logits
.
append
(
batch_logits
)
# We rely on the fact that attention in the forward pass only looks at previous
# tokens here, so the logits returned for each token are an accurate representation
# of what the model would have predicted at that point.
#
# Example, we have a context window of 512, we will compute perplexity for each of the
# last 256 tokens. Then, we split the input up into context window size chunks to
# process the entire prompt.
for
j
in
range
(
min
(
512
,
n_ctx
//
2
),
n_ctx
-
1
):
tok_logits
=
logits
[
0
][
0
][
j
].
cpu
().
numpy
()
# Compute the probability of the next token
prob
=
self
.
softmax
(
tok_logits
)[
tokens
[
0
][
start
+
j
+
1
]]
# Update the negative log likelihood and the count of processed tokens
nll
+=
-
np
.
log
(
prob
,
where
=
prob
>
0
)
count
+=
1
return
nll
,
count
def
_compute_batch_logits
(
self
,
tokens
,
batch_start
,
batch_size
):
"""
Computes the logits for a batch of tokens.
Parameters
----------
tokens : torch.Tensor
The tokenized text.
batch_start : int
The start index of the batch.
batch_size : int
The size of the batch.
Returns
-------
torch.Tensor
The logits for the batch of tokens.
"""
# Compute the logits without keeping track of gradients
with
torch
.
no_grad
():
outputs
=
self
.
_model
(
tokens
[:,
batch_start
:
batch_start
+
batch_size
])
return
outputs
.
logits
.
detach
()
autogptq_extension/exllamav2/config.h
0 → 100644
View file @
da900c3b
#ifndef _config_h
#define _config_h
#define MAX_Q_GEMM_ROWS 50
#define QMODE_2BIT 1
#define QMODE_3BIT 1
#define QMODE_4BIT 1
#define QMODE_5BIT 1
#define QMODE_6BIT 0
#define QMODE_8BIT 0
#endif
autogptq_extension/exllamav2/cpp/util.h
0 → 100644
View file @
da900c3b
#ifndef _util_h
#define _util_h
#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#endif
autogptq_extension/exllamav2/ext_hip.cpp
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <torch/extension.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>
#include "config.h"
#include "hip/q_matrix.cuh"
#include "hip/q_gemm.cuh"
#include "cpp/util.h"
// Some decluttering macros
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
// Quant matrix
uintptr_t
make_q_matrix
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
torch
::
Tensor
q_invperm
,
torch
::
Tensor
q_scale
,
torch
::
Tensor
q_scale_max
,
torch
::
Tensor
q_groups
,
torch
::
Tensor
gptq_qzeros
,
torch
::
Tensor
gptq_scales
,
torch
::
Tensor
gptq_g_idx
,
torch
::
Tensor
temp_dq
)
{
TORCH_CHECK_DTYPE
(
q_weight
,
kInt
);
TORCH_CHECK_DTYPE_OPT
(
q_perm
,
kShort
);
TORCH_CHECK_DTYPE_OPT
(
q_invperm
,
kShort
);
TORCH_CHECK_DTYPE_OPT
(
q_scale
,
kInt
);
TORCH_CHECK_DTYPE_OPT
(
q_scale_max
,
kHalf
);
TORCH_CHECK_DTYPE_OPT
(
q_groups
,
kShort
);
TORCH_CHECK_DTYPE_OPT
(
gptq_qzeros
,
kInt
);
TORCH_CHECK_DTYPE_OPT
(
gptq_scales
,
kHalf
);
TORCH_CHECK_DTYPE_OPT
(
gptq_g_idx
,
kInt
);
TORCH_CHECK_SHAPES
(
q_perm
,
0
,
q_invperm
,
0
,
1
);
int
device
=
q_weight
.
device
().
index
();
int
width
=
q_weight
.
size
(
1
);
int
groups
;
int
height
;
if
(
!
q_scale
.
device
().
is_meta
())
{
TORCH_CHECK_SHAPES
(
q_weight
,
1
,
q_scale
,
1
,
8
);
TORCH_CHECK_SHAPES
(
q_scale_max
,
0
,
q_scale
,
0
,
1
);
groups
=
q_scale
.
size
(
0
);
height
=
q_invperm
.
size
(
0
);
}
else
{
TORCH_CHECK_SHAPES
(
q_weight
,
1
,
gptq_qzeros
,
1
,
8
);
TORCH_CHECK_SHAPES
(
q_weight
,
1
,
gptq_scales
,
1
,
1
);
groups
=
gptq_qzeros
.
size
(
0
);
height
=
q_weight
.
size
(
0
)
*
8
;
}
TORCH_CHECK
(
temp_dq
.
size
(
0
)
>=
width
*
height
,
"Insufficient size of temp_dq buffer"
)
QMatrix
*
m
=
new
QMatrix
(
device
,
height
,
width
,
groups
,
(
uint32_t
*
)
q_weight
.
data_ptr
(),
q_perm
.
device
().
is_meta
()
?
NULL
:
(
uint16_t
*
)
q_perm
.
data_ptr
(),
q_invperm
.
device
().
is_meta
()
?
NULL
:
(
uint16_t
*
)
q_invperm
.
data_ptr
(),
q_scale
.
device
().
is_meta
()
?
NULL
:
(
uint32_t
*
)
q_scale
.
data_ptr
(),
q_scale_max
.
device
().
is_meta
()
?
NULL
:
(
half
*
)
q_scale_max
.
data_ptr
(),
q_groups
.
device
().
is_meta
()
?
NULL
:
(
uint16_t
*
)
q_groups
.
data_ptr
(),
gptq_qzeros
.
device
().
is_meta
()
?
NULL
:
(
uint32_t
*
)
gptq_qzeros
.
data_ptr
(),
gptq_scales
.
device
().
is_meta
()
?
NULL
:
(
half
*
)
gptq_scales
.
data_ptr
(),
gptq_g_idx
.
device
().
is_meta
()
?
NULL
:
(
uint32_t
*
)
gptq_g_idx
.
data_ptr
(),
(
half
*
)
temp_dq
.
data_ptr
()
);
return
reinterpret_cast
<
uintptr_t
>
(
m
);
}
void
gemm_half_q_half
(
torch
::
Tensor
a
,
uintptr_t
b
,
torch
::
Tensor
c
,
bool
force_cuda
)
{
QMatrix
*
qm
=
reinterpret_cast
<
QMatrix
*>
(
b
);
TORCH_CHECK_DTYPE
(
a
,
kHalf
);
TORCH_CHECK_DTYPE
(
c
,
kHalf
);
TORCH_CHECK_SHAPES
(
a
,
0
,
c
,
0
,
1
);
TORCH_CHECK
(
qm
->
height
==
a
.
size
(
1
),
"a and b have incompatible shapes"
)
TORCH_CHECK
(
qm
->
width
==
c
.
size
(
1
),
"b and c have incompatible shapes"
)
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
a
));
gemm_half_q_half_cuda
(
at
::
cuda
::
getCurrentCUDABlasHandle
(),
(
const
half
*
)
a
.
data_ptr
(),
qm
,
(
half
*
)
c
.
data_ptr
(),
c
.
size
(
0
),
// m
c
.
size
(
1
),
// n
a
.
size
(
1
),
// k
true
,
NULL
,
force_cuda
);
}
// Bindings
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"make_q_matrix"
,
&
make_q_matrix
,
"make_q_matrix"
);
m
.
def
(
"gemm_half_q_half"
,
&
gemm_half_q_half
,
"gemm_half_q_half"
);
}
autogptq_extension/exllamav2/hip/compat.cuh
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _compat_cuh
#define _compat_cuh
// atomicAdd for half types, to support CC < 7.x
__device__
__forceinline__
void
atomicAdd_half
(
half
*
address
,
half
val
)
{
unsigned
int
*
address_as_ui
=
(
unsigned
int
*
)
((
char
*
)
address
-
((
size_t
)
address
&
2
));
unsigned
int
old
=
*
address_as_ui
;
unsigned
int
assumed
;
do
{
assumed
=
old
;
__half_raw
hsum
;
hsum
.
x
=
(
size_t
)
address
&
2
?
(
old
>>
16
)
:
(
old
&
0xffff
);
half
tmpres
=
__hadd
(
hsum
,
val
);
hsum
=
__half_raw
(
tmpres
);
old
=
(
size_t
)
address
&
2
?
(
old
&
0xffff
)
|
(
hsum
.
x
<<
16
)
:
(
old
&
0xffff0000
)
|
hsum
.
x
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
old
);
}
while
(
assumed
!=
old
);
}
// atomicAdd for half2 types
__device__
__forceinline__
void
atomicAdd_half2
(
half2
*
address
,
half2
val
)
{
unsigned
int
*
address_as_ui
=
(
unsigned
int
*
)
address
;
unsigned
int
old
=
*
address_as_ui
;
unsigned
int
assumed
;
do
{
assumed
=
old
;
half2
old_val
=
*
((
half2
*
)
&
old
);
half2
new_val
=
__hadd2
(
old_val
,
val
);
old
=
atomicCAS
(
address_as_ui
,
assumed
,
*
((
unsigned
int
*
)
&
new_val
));
}
while
(
assumed
!=
old
);
}
//
#if defined(__DTK_ARCH__) || defined(USE_ROCM)
#if __DTK_ARCH__ < 700 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd_1
(
half
*
address
,
half
val
)
{
atomicAdd_half
(
address
,
val
);
}
#if __DTK_ARCH__ < 600 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd_1
(
half2
*
address
,
half2
val
)
{
atomicAdd_half2
(
address
,
val
);
}
#endif
#endif
#endif
#endif
autogptq_extension/exllamav2/hip/compat_gemm.cuh
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _compat_gemm_cuh
#define _compat_gemm_cuh
#if defined(USE_ROCM)
// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
// for symbols as hipblasHalf.
#include <hipblas/hipblas.h>
__host__
__forceinline__
hipblasStatus_t
__compat_hipblasHgemm
(
hipblasHandle_t
handle
,
hipblasOperation_t
transA
,
hipblasOperation_t
transB
,
int
m
,
int
n
,
int
k
,
const
half
*
alpha
,
const
half
*
AP
,
int
lda
,
const
half
*
BP
,
int
ldb
,
const
half
*
beta
,
half
*
CP
,
int
ldc
)
{
return
hipblasHgemm
(
handle
,
transA
,
transB
,
m
,
n
,
k
,
reinterpret_cast
<
const
hipblasHalf
*>
(
alpha
),
reinterpret_cast
<
const
hipblasHalf
*>
(
AP
),
lda
,
reinterpret_cast
<
const
hipblasHalf
*>
(
BP
),
ldb
,
reinterpret_cast
<
const
hipblasHalf
*>
(
beta
),
reinterpret_cast
<
hipblasHalf
*>
(
CP
),
ldc
);
}
#define hipblasHgemm __compat_hipblasHgemm
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
#endif
autogptq_extension/exllamav2/hip/matrix_view.cuh
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "../hip/quant/qdq_util.cuh"
class
MatrixView_half
{
public:
const
half
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_half
(
const
half
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
half
item
(
int
row
,
int
column
)
const
{
return
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
half2
item_half2
(
int
row
,
int
column
)
const
{
return
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
];
}
__device__
__forceinline__
half2
item_half2half2
(
int
row
,
int
column
)
const
{
return
__half2half2
(
data
[
row
*
width
+
column
]);
}
__device__
__forceinline__
const
half
*
item_ptr
(
int
row
,
int
column
)
const
{
return
&
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
void
item4
(
half
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__low2half
(
i01
);
items
[
1
]
=
__high2half
(
i01
);
items
[
2
]
=
__low2half
(
i23
);
items
[
3
]
=
__high2half
(
i23
);
}
__device__
__forceinline__
void
item4_f
(
float
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__half2float
(
__low2half
(
i01
));
items
[
1
]
=
__half2float
(
__high2half
(
i01
));
items
[
2
]
=
__half2float
(
__low2half
(
i23
));
items
[
3
]
=
__half2float
(
__high2half
(
i23
));
}
__device__
__forceinline__
void
item4_h2
(
half2
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__half2half2
(
__low2half
(
i01
));
items
[
1
]
=
__half2half2
(
__high2half
(
i01
));
items
[
2
]
=
__half2half2
(
__low2half
(
i23
));
items
[
3
]
=
__half2half2
(
__high2half
(
i23
));
}
};
class
MatrixView_half_rw
{
public:
half
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_half_rw
(
half
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
half
item
(
int
row
,
int
column
)
const
{
return
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
half2
item_half2
(
int
row
,
int
column
)
const
{
return
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
];
}
__device__
__forceinline__
half
*
item_ptr
(
int
row
,
int
column
)
{
return
&
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
void
set
(
int
row
,
int
column
,
half
value
)
{
data
[
row
*
width
+
column
]
=
value
;
}
__device__
__forceinline__
void
set_half2
(
int
row
,
int
column
,
half2
value
)
{
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
]
=
value
;
}
__device__
__forceinline__
void
set4
(
int
row
,
int
column
,
half
v0
,
half
v1
,
half
v2
,
half
v3
)
{
half2
v01
=
__halves2half2
(
v0
,
v1
);
half2
v23
=
__halves2half2
(
v2
,
v3
);
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
ptr
[
0
]
=
v01
;
ptr
[
1
]
=
v23
;
}
};
class
MatrixView_q4_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q4_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
return
(
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
)
&
0x0f
;
}
__device__
__forceinline__
void
item2
(
int
(
&
items
)[
2
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
uint32_t
d
=
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
;
items
[
0
]
=
d
&
0x0f
;
items
[
1
]
=
(
d
>>
4
)
&
0x0f
;
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
uint32_t
d
=
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
;
items
[
0
]
=
d
&
0x0f
;
items
[
1
]
=
(
d
>>
4
)
&
0x0f
;
items
[
2
]
=
(
d
>>
8
)
&
0x0f
;
items
[
3
]
=
(
d
>>
12
)
&
0x0f
;
}
};
#endif
\ No newline at end of file
autogptq_extension/exllamav2/hip/q_gemm.cuh
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _q_gemm_cuh
#define _q_gemm_cuh
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/hip/HIPContext.h>
#include "../hip/q_matrix.cuh"
void
gemm_half_q_half_cuda
(
hipblasHandle_t
cublas_handle
,
const
half
*
a
,
QMatrix
*
b
,
half
*
c
,
int
size_m
,
int
size_n
,
int
size_k
,
bool
clear
=
false
,
half
*
reconstruct
=
NULL
,
bool
force_cuda
=
false
);
void
clear_tensor_cuda
(
half
*
c
,
int
size_m
,
int
size_n
);
#endif
\ No newline at end of file
autogptq_extension/exllamav2/hip/q_gemm.hip
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include "../hip/q_gemm.cuh"
#include "../hip/util.cuh"
#include "../hip/matrix_view.cuh"
#include "../config.h"
#include "../hip/quant/qdq_2.cuh"
#include "../hip/quant/qdq_3.cuh"
#include "../hip/quant/qdq_4.cuh"
#include "../hip/quant/qdq_5.cuh"
#include "../hip/quant/qdq_6.cuh"
#include "../hip/quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256
#include "../hip/q_gemm_kernel.cuh"
#include "../hip/q_gemm_kernel_gptq.cuh"
#include "../hip/compat_gemm.cuh"
void gemm_half_q_half_cuda_part
(
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
bool clear
)
{
if (!b->is_gptq)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
hipLaunchKernelGGL(( kernel), dim3(gridDim), dim3(blockDim), 0, 0,
a,
b->cuda_q_weight,
b->cuda_q_scale,
b->cuda_q_scale_max,
c,
size_m,
size_n,
size_k,
b->groups,
b->groupsize,
b->cuda_q_perm,
b->rows_8,
b->rows_6,
b->rows_5,
b->rows_4,
b->rows_3,
b->rows_2,
clear
);
}
else
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
// DBGX((uint64_t) b->cuda_q_perm);
// DBGI(b->rows_4);
// DBGI(b->height);
hipLaunchKernelGGL(( kernel), dim3(gridDim), dim3(blockDim), 0, 0,
a,
b->cuda_q_weight,
b->cuda_gptq_qzeros,
b->cuda_gptq_scales,
c,
size_m,
size_n,
size_k,
b->groups,
b->groupsize,
b->cuda_q_perm,
b->rows_4,
clear
);
}
}
void gemm_half_q_half_cuda
(
hipblasHandle_t cublas_handle,
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
bool clear,
half* temp_dq,
bool force_cuda
)
{
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
{
//printf("cublas\n");
// Reconstruct FP16 matrix, then cuBLAS
if (!temp_dq) temp_dq = b->temp_dq;
b->reconstruct(temp_dq);
//hipblasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
const half alpha = __float2half(1.0f);
const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
hipblasHgemm(cublas_handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
size_n, size_m, size_k,
&alpha, temp_dq, size_n,
a, size_k,
&beta, c, size_n);
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
//hipblasSgemmEx(cublas_handle,
// HIPBLAS_OP_N,
// HIPBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, HIP_R_16F, size_n,
// a, HIP_R_16F, size_k,
// &beta, c, HIP_R_16F, size_n);
//const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f;
//hipblasGemmEx(cublas_handle,
// HIPBLAS_OP_N, HIPBLAS_OP_N,
// size_n, size_m, size_k,
// &alpha, temp_dq, HIP_R_16F, size_n,
// a, HIP_R_16F, size_k,
// &beta, c, HIP_R_16F, size_n,
// HIP_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
}
else
{
//printf("cuda\n");
// Quantized matmul
//if (clear) clear_tensor_cuda(c, size_m, size_n);
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
int last_chunk_size = size_m - last_chunk;
if (max_chunks)
{
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
}
if (last_chunk_size)
{
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
}
}
}
__global__ void clear_kernel
(
half* __restrict__ c,
const int size_m,
const int size_n
)
{
int m = blockIdx.y;
int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
if (n >= size_n) return;
int4* c_ptr = (int4*)(c + m * size_n + n);
*c_ptr = {};
}
void clear_tensor_cuda
(
half* c,
int size_m,
int size_n
)
{
return;
dim3 blockDim, gridDim;
blockDim.x = CLEAR_N_SIZE;
blockDim.y = 1;
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
gridDim.y = size_m;
hipLaunchKernelGGL(( clear_kernel), dim3(gridDim), dim3(blockDim), 0, 0, c, size_m, size_n);
}
autogptq_extension/exllamav2/hip/q_gemm_kernel.cuh
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "../hip/compat.cuh"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
__forceinline__
__device__
half2
dot22_8
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
,
const
half2
g_result
,
const
half
qs_h
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hfma2
(
result
,
__halves2half2
(
qs_h
,
qs_h
),
g_result
);
}
__forceinline__
__device__
half2
dot22_16
(
half2
(
&
dq
)[
8
],
const
half
*
a_ptr
,
const
half2
g_result
,
const
half
qs_h
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hfma2
(
result
,
__halves2half2
(
qs_h
,
qs_h
),
g_result
);
}
__forceinline__
__device__
half2
dot22_32
(
half2
(
&
dq
)[
16
],
const
half
*
a_ptr
,
const
half2
g_result
,
const
half
qs_h
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
i
+=
1
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hfma2
(
result
,
__halves2half2
(
qs_h
,
qs_h
),
g_result
);
}
__forceinline__
__device__
float
dot22_8_f
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
,
const
float
g_result
,
const
float
qs_f
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
float
result_f
=
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
return
fma
(
result_f
,
qs_f
,
g_result
);
}
__forceinline__
__device__
float
dot22_16_f
(
half2
(
&
dq
)[
8
],
const
half
*
a_ptr
,
const
float
g_result
,
const
float
qs_f
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
float
result_f
=
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
return
fma
(
result_f
,
qs_f
,
g_result
);
}
__forceinline__
__device__
float
dot22_32_f
(
half2
(
&
dq
)[
16
],
const
half
*
a_ptr
,
const
float
g_result
,
const
float
qs_f
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
i
+=
1
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
float
result_f
=
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
return
fma
(
result_f
,
qs_f
,
g_result
);
}
typedef
void
(
*
fp_gemm_half_q_half_kernel
)
(
const
half
*
,
const
uint32_t
*
,
const
uint32_t
*
,
const
half
*
,
half
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
uint16_t
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
bool
);
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_kernel
(
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_scale
,
const
half
*
__restrict__
b_q_scale_max
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_n
,
const
int
size_k
,
const
int
groups
,
const
int
groupsize
,
const
uint16_t
*
__restrict__
b_q_perm
,
const
int
rows_8
,
const
int
rows_6
,
const
int
rows_5
,
const
int
rows_4
,
const
int
rows_3
,
const
int
rows_2
,
const
bool
clear
)
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q4_row
b_q_scale_
(
b_q_scale
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
// Block
int
offset_n
=
blockIdx
.
x
*
BLOCK_KN_SIZE
*
4
;
int
offset_m
=
blockIdx
.
y
*
m_count
;
int
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
// Preload block_a
__shared__
half
block_a
[
m_count
][
BLOCK_KN_SIZE
];
if
(
offset_k
+
t
<
end_k
)
{
for
(
int
m
=
0
;
m
<
m_count
;
++
m
)
{
const
half
*
a_ptr
=
a_
.
item_ptr
(
offset_m
+
m
,
0
);
half
*
block_a_ptr
=
block_a
[
m
];
half
a0
=
a_ptr
[
b_q_perm
[
offset_k
+
t
]];
block_a_ptr
[
t
]
=
a0
;
}
}
// Clear
if
(
n
>=
size_n
)
return
;
if
(
clear
&&
blockIdx
.
z
==
0
)
// && (threadIdx.x & 1) == 0)
{
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
*
((
uint64_t
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
))
=
0
;
}
__syncthreads
();
// Find initial group
int
group
=
offset_k
/
groupsize
;
// Preload scales
float
scales
[
MAX_GROUPS_IN_BLOCK
][
4
];
int
groups_in_block
=
DIVIDE
((
end_k
-
offset_k
),
groupsize
);
for
(
int
g
=
0
;
g
<
groups_in_block
;
g
++
)
{
int
qscales
[
4
];
b_q_scale_
.
item4
(
qscales
,
group
+
g
,
n
);
qscales
[
0
]
++
;
qscales
[
1
]
++
;
qscales
[
2
]
++
;
qscales
[
3
]
++
;
float
maxscale
=
__half2float
(
b_q_scale_max
[
group
+
g
]);
scales
[
g
][
0
]
=
__int2float_rn
(
qscales
[
0
]
*
qscales
[
0
])
*
maxscale
;
scales
[
g
][
1
]
=
__int2float_rn
(
qscales
[
1
]
*
qscales
[
1
])
*
maxscale
;
scales
[
g
][
2
]
=
__int2float_rn
(
qscales
[
2
]
*
qscales
[
2
])
*
maxscale
;
scales
[
g
][
3
]
=
__int2float_rn
(
qscales
[
3
]
*
qscales
[
3
])
*
maxscale
;
}
// a, b offset
int
pre_rows_8
=
min
(
rows_8
,
offset_k
);
int
pre_rows_6
=
offset_k
>
rows_8
?
min
(
rows_6
,
offset_k
)
-
rows_8
:
0
;
int
pre_rows_5
=
offset_k
>
rows_6
?
min
(
rows_5
,
offset_k
)
-
rows_6
:
0
;
int
pre_rows_4
=
offset_k
>
rows_5
?
min
(
rows_4
,
offset_k
)
-
rows_5
:
0
;
int
pre_rows_3
=
offset_k
>
rows_4
?
min
(
rows_3
,
offset_k
)
-
rows_4
:
0
;
int
pre_rows_2
=
offset_k
>
rows_3
?
min
(
rows_2
,
offset_k
)
-
rows_3
:
0
;
int
qk
=
0
;
qk
+=
pre_rows_8
/
32
*
8
;
qk
+=
pre_rows_6
/
32
*
6
;
qk
+=
pre_rows_5
/
32
*
5
;
qk
+=
pre_rows_4
/
32
*
4
;
qk
+=
pre_rows_3
/
32
*
3
;
qk
+=
pre_rows_2
/
32
*
2
;
const
uint32_t
*
b_ptr
=
b_q_weight
+
qk
*
size_n
+
n
;
const
half
*
a_ptr
=
&
block_a
[
0
][
0
];
int
a_stride
=
BLOCK_KN_SIZE
;
// Initial group
int
scales_idx
=
0
;
float
qs_f0
=
scales
[
scales_idx
][
0
];
float
qs_f1
=
scales
[
scales_idx
][
1
];
float
qs_f2
=
scales
[
scales_idx
][
2
];
float
qs_f3
=
scales
[
scales_idx
][
3
];
int
nextgroup
=
offset_k
+
groupsize
;
// Column result
float
block_c
[
m_count
][
4
]
=
{};
// Dequantize groups
int
k
=
offset_k
;
while
(
k
<
rows_8
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int4
load_int4
[
2
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
1
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
4
];
dequant_8bit_8
(
load_int4
[
0
].
x
,
load_int4
[
1
].
x
,
dq
[
0
],
size_n
);
dequant_8bit_8
(
load_int4
[
0
].
y
,
load_int4
[
1
].
y
,
dq
[
1
],
size_n
);
dequant_8bit_8
(
load_int4
[
0
].
z
,
load_int4
[
1
].
z
,
dq
[
2
],
size_n
);
dequant_8bit_8
(
load_int4
[
0
].
w
,
load_int4
[
1
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_8_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_8_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_8_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_8_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
8
;
}
k
+=
32
;
}
while
(
k
<
rows_6
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
int4
load_int4
[
3
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
1
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
2
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
8
];
dequant_6bit_16
(
load_int4
[
0
].
x
,
load_int4
[
1
].
x
,
load_int4
[
2
].
x
,
dq
[
0
],
size_n
);
dequant_6bit_16
(
load_int4
[
0
].
y
,
load_int4
[
1
].
y
,
load_int4
[
2
].
y
,
dq
[
1
],
size_n
);
dequant_6bit_16
(
load_int4
[
0
].
z
,
load_int4
[
1
].
z
,
load_int4
[
2
].
z
,
dq
[
2
],
size_n
);
dequant_6bit_16
(
load_int4
[
0
].
w
,
load_int4
[
1
].
w
,
load_int4
[
2
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_16_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_16_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_16_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_16_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
16
;
}
k
+=
32
;
}
while
(
k
<
rows_5
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
1
;
j
++
)
{
int4
load_int4
[
5
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
1
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
2
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
3
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
4
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
16
];
dequant_5bit_32
(
load_int4
[
0
].
x
,
load_int4
[
1
].
x
,
load_int4
[
2
].
x
,
load_int4
[
3
].
x
,
load_int4
[
4
].
x
,
dq
[
0
],
size_n
);
dequant_5bit_32
(
load_int4
[
0
].
y
,
load_int4
[
1
].
y
,
load_int4
[
2
].
y
,
load_int4
[
3
].
y
,
load_int4
[
4
].
y
,
dq
[
1
],
size_n
);
dequant_5bit_32
(
load_int4
[
0
].
z
,
load_int4
[
1
].
z
,
load_int4
[
2
].
z
,
load_int4
[
3
].
z
,
load_int4
[
4
].
z
,
dq
[
2
],
size_n
);
dequant_5bit_32
(
load_int4
[
0
].
w
,
load_int4
[
1
].
w
,
load_int4
[
2
].
w
,
load_int4
[
3
].
w
,
load_int4
[
4
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_32_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_32_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_32_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_32_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
32
;
}
k
+=
32
;
}
while
(
k
<
rows_4
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int4
load_int4
[
1
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
4
];
dequant_4bit_8
(
load_int4
[
0
].
x
,
dq
[
0
],
size_n
);
dequant_4bit_8
(
load_int4
[
0
].
y
,
dq
[
1
],
size_n
);
dequant_4bit_8
(
load_int4
[
0
].
z
,
dq
[
2
],
size_n
);
dequant_4bit_8
(
load_int4
[
0
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_8_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_8_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_8_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_8_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
8
;
}
k
+=
32
;
}
while
(
k
<
rows_3
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
1
;
j
++
)
{
int4
load_int4
[
3
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
1
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
load_int4
[
2
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
16
];
dequant_3bit_32
(
load_int4
[
0
].
x
,
load_int4
[
1
].
x
,
load_int4
[
2
].
x
,
dq
[
0
],
size_n
);
dequant_3bit_32
(
load_int4
[
0
].
y
,
load_int4
[
1
].
y
,
load_int4
[
2
].
y
,
dq
[
1
],
size_n
);
dequant_3bit_32
(
load_int4
[
0
].
z
,
load_int4
[
1
].
z
,
load_int4
[
2
].
z
,
dq
[
2
],
size_n
);
dequant_3bit_32
(
load_int4
[
0
].
w
,
load_int4
[
1
].
w
,
load_int4
[
2
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_32_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_32_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_32_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_32_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
32
;
}
k
+=
32
;
}
while
(
k
<
rows_2
&&
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
scales_idx
++
;
qs_f0
=
scales
[
scales_idx
][
0
];
qs_f1
=
scales
[
scales_idx
][
1
];
qs_f2
=
scales
[
scales_idx
][
2
];
qs_f3
=
scales
[
scales_idx
][
3
];
nextgroup
+=
groupsize
;
}
#pragma unroll
for
(
int
j
=
0
;
j
<
2
;
j
++
)
{
int4
load_int4
[
1
];
load_int4
[
0
]
=
*
((
int4
*
)
b_ptr
);
b_ptr
+=
size_n
;
half2
dq
[
4
][
8
];
dequant_2bit_16
(
load_int4
[
0
].
x
,
dq
[
0
],
size_n
);
dequant_2bit_16
(
load_int4
[
0
].
y
,
dq
[
1
],
size_n
);
dequant_2bit_16
(
load_int4
[
0
].
z
,
dq
[
2
],
size_n
);
dequant_2bit_16
(
load_int4
[
0
].
w
,
dq
[
3
],
size_n
);
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
dot22_16_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
0
],
qs_f0
);
block_c
[
m
][
1
]
=
dot22_16_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
1
],
qs_f1
);
block_c
[
m
][
2
]
=
dot22_16_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
2
],
qs_f2
);
block_c
[
m
][
3
]
=
dot22_16_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
,
block_c
[
m
][
3
],
qs_f3
);
}
a_ptr
+=
16
;
}
k
+=
32
;
}
// Accumulate column sums in c
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
half2
*
out
=
(
half2
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
);
half2
result01
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
0
]),
__float2half_rn
(
block_c
[
m
][
1
]));
half2
result23
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
2
]),
__float2half_rn
(
block_c
[
m
][
3
]));
atomicAdd
(
out
,
result01
);
atomicAdd
(
out
+
1
,
result23
);
}
}
fp_gemm_half_q_half_kernel
pick_gemm_half_q_half_kernel
(
bool
first_block
,
const
int
m_count
)
{
#if BLOCK_M_SIZE_MAX >= 1
if
(
m_count
==
1
)
return
gemm_half_q_half_kernel
<
true
,
1
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if
(
m_count
==
2
)
return
gemm_half_q_half_kernel
<
true
,
2
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if
(
m_count
==
3
)
return
gemm_half_q_half_kernel
<
true
,
3
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if
(
m_count
==
4
)
return
gemm_half_q_half_kernel
<
true
,
4
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if
(
m_count
==
5
)
return
gemm_half_q_half_kernel
<
true
,
5
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if
(
m_count
==
6
)
return
gemm_half_q_half_kernel
<
true
,
6
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if
(
m_count
==
7
)
return
gemm_half_q_half_kernel
<
true
,
7
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if
(
m_count
==
8
)
return
gemm_half_q_half_kernel
<
true
,
8
>
;
#endif
return
NULL
;
}
autogptq_extension/exllamav2/hip/q_gemm_kernel_gptq.cuh
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include "../hip/compat.cuh"
__forceinline__
__device__
half2
dot22_8
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
,
const
half2
g_result
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hadd2
(
result
,
g_result
);
}
__forceinline__
__device__
float
dot22_8_f
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
}
typedef
void
(
*
fp_gemm_half_q_half_gptq_kernel
)
(
const
half
*
,
const
uint32_t
*
,
const
uint32_t
*
,
const
half
*
,
half
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
,
const
uint16_t
*
,
const
int
,
const
bool
);
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_gptq_kernel
(
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_gptq_qzeros
,
const
half
*
__restrict__
b_gptq_scales
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_n
,
const
int
size_k
,
const
int
groups
,
const
int
groupsize
,
const
uint16_t
*
__restrict__
b_q_perm
,
const
int
rows_4
,
const
bool
clear
)
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q4_row
b_gptq_qzeros_
(
b_gptq_qzeros
,
groups
,
size_n
);
MatrixView_half
b_gptq_scales_
(
b_gptq_scales
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
// Block
int
offset_n
=
blockIdx
.
x
*
BLOCK_KN_SIZE
*
4
;
int
offset_m
=
blockIdx
.
y
*
m_count
;
int
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
// Preload block_a
__shared__
half
block_a
[
m_count
][
BLOCK_KN_SIZE
];
if
(
offset_k
+
t
<
end_k
)
{
for
(
int
m
=
0
;
m
<
m_count
;
++
m
)
{
const
half
*
a_ptr
=
a_
.
item_ptr
(
offset_m
+
m
,
0
);
half
*
block_a_ptr
=
block_a
[
m
];
half
a0
;
if
(
b_q_perm
)
a0
=
a_ptr
[
b_q_perm
[
offset_k
+
t
]];
else
a0
=
a_ptr
[
offset_k
+
t
];
block_a_ptr
[
t
]
=
a0
;
}
}
// Zero output
if
(
n
>=
size_n
)
return
;
if
(
clear
&&
blockIdx
.
z
==
0
)
// && (threadIdx.x & 1) == 0)
{
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
*
((
uint64_t
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
))
=
0
;
}
__syncthreads
();
// Find initial group
int
group
=
offset_k
/
groupsize
;
int
nextgroup
=
offset_k
+
groupsize
;
// a, b offset
int
qk
=
offset_k
/
(
32
/
4
);
const
uint32_t
*
b_ptr
=
b_q_weight
+
qk
*
size_n
+
n
;
const
half
*
a_ptr
=
&
block_a
[
0
][
0
];
int
a_stride
=
BLOCK_KN_SIZE
;
// Initial group
int
zeros
[
4
];
float
scales
[
4
];
half2
z1z16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_f
(
scales
,
group
,
n
);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero
((
zeros
[
0
]
+
1
)
&
0x0f
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
((
zeros
[
1
]
+
1
)
&
0x0f
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
((
zeros
[
2
]
+
1
)
&
0x0f
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
((
zeros
[
3
]
+
1
)
&
0x0f
,
z1z16
[
3
],
y1y16
[
3
]);
// __syncthreads();
// Column result
float
block_c
[
m_count
][
4
]
=
{};
// Dequantize and multiply
int
k
=
offset_k
;
while
(
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
nextgroup
+=
groupsize
;
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_f
(
scales
,
group
,
n
);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero
((
zeros
[
0
]
+
1
)
&
0x0f
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
((
zeros
[
1
]
+
1
)
&
0x0f
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
((
zeros
[
2
]
+
1
)
&
0x0f
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
((
zeros
[
3
]
+
1
)
&
0x0f
,
z1z16
[
3
],
y1y16
[
3
]);
}
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
const
int4
*
b_ptr4
=
(
int4
*
)
b_ptr
;
int4
load_int4
=
*
b_ptr4
;
half2
dq
[
4
][
4
];
dequant_4bit_8_gptq
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
#pragma unroll
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
fma
(
dot22_8_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
),
scales
[
0
],
block_c
[
m
][
0
]);
block_c
[
m
][
1
]
=
fma
(
dot22_8_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
),
scales
[
1
],
block_c
[
m
][
1
]);
block_c
[
m
][
2
]
=
fma
(
dot22_8_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
),
scales
[
2
],
block_c
[
m
][
2
]);
block_c
[
m
][
3
]
=
fma
(
dot22_8_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
),
scales
[
3
],
block_c
[
m
][
3
]);
}
b_ptr
+=
size_n
;
a_ptr
+=
8
;
}
k
+=
32
;
}
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
half2
*
out
=
(
half2
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
);
half2
result01
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
0
]),
__float2half_rn
(
block_c
[
m
][
1
]));
half2
result23
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
2
]),
__float2half_rn
(
block_c
[
m
][
3
]));
atomicAdd
(
out
,
result01
);
atomicAdd
(
out
+
1
,
result23
);
}
}
fp_gemm_half_q_half_gptq_kernel
pick_gemm_half_q_half_gptq_kernel
(
bool
first_block
,
const
int
m_count
)
{
#if BLOCK_M_SIZE_MAX >= 1
if
(
m_count
==
1
)
return
gemm_half_q_half_gptq_kernel
<
true
,
1
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if
(
m_count
==
2
)
return
gemm_half_q_half_gptq_kernel
<
true
,
2
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if
(
m_count
==
3
)
return
gemm_half_q_half_gptq_kernel
<
true
,
3
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if
(
m_count
==
4
)
return
gemm_half_q_half_gptq_kernel
<
true
,
4
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if
(
m_count
==
5
)
return
gemm_half_q_half_gptq_kernel
<
true
,
5
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if
(
m_count
==
6
)
return
gemm_half_q_half_gptq_kernel
<
true
,
6
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if
(
m_count
==
7
)
return
gemm_half_q_half_gptq_kernel
<
true
,
7
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if
(
m_count
==
8
)
return
gemm_half_q_half_gptq_kernel
<
true
,
8
>
;
#endif
return
NULL
;
}
autogptq_extension/exllamav2/hip/q_matrix.cuh
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _q_matrix_cuh
#define _q_matrix_cuh
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>
#define MAX_SUPERGROUPS 16
class
QMatrix
{
public:
int
device
;
bool
is_gptq
;
int
height
;
int
width
;
int
groups
;
int
groupsize
;
int
rows_8
;
int
rows_6
;
int
rows_5
;
int
rows_4
;
int
rows_3
;
int
rows_2
;
uint32_t
*
cuda_q_weight
=
NULL
;
uint16_t
*
cuda_q_perm
=
NULL
;
uint16_t
*
cuda_q_invperm
=
NULL
;
uint32_t
*
cuda_q_scale
=
NULL
;
half
*
cuda_q_scale_max
=
NULL
;
uint16_t
*
cuda_q_groups
=
NULL
;
uint32_t
*
cuda_gptq_qzeros
=
NULL
;
half
*
cuda_gptq_scales
=
NULL
;
half
*
temp_dq
;
bool
failed
;
QMatrix
(
const
int
_device
,
const
int
_height
,
const
int
_width
,
const
int
_groups
,
uint32_t
*
_q_weight
,
uint16_t
*
_q_perm
,
uint16_t
*
_q_invperm
,
uint32_t
*
_q_scale
,
half
*
_q_scale_max
,
uint16_t
*
_q_groups
,
uint32_t
*
_gptq_qzeros
,
half
*
_gptq_scales
,
uint32_t
*
_gptq_g_idx
,
half
*
_temp_dq
);
~
QMatrix
();
void
reconstruct
(
half
*
out
);
bool
make_sequential
(
const
uint32_t
*
cpu_g_idx
);
private:
};
#endif
autogptq_extension/exllamav2/hip/q_matrix.hip
0 → 100644
View file @
da900c3b
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include "../hip/q_matrix.cuh"
#include "../hip/matrix_view.cuh"
#include "../hip/util.cuh"
#include "../hip/quant/qdq_2.cuh"
#include "../hip/quant/qdq_3.cuh"
#include "../hip/quant/qdq_4.cuh"
#include "../hip/quant/qdq_5.cuh"
#include "../hip/quant/qdq_6.cuh"
#include "../hip/quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define THREADS_X 32
#define THREADS_Y 32
// Shuffle quantized data on load
__global__ void shuffle_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
// QMatrix constructor
QMatrix::QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
) :
device(_device),
height(_height),
width(_width),
groups(_groups),
temp_dq(_temp_dq)
{
hipSetDevice(device);
failed = false;
cuda_q_weight = _q_weight;
cuda_q_perm = _q_perm;
cuda_q_invperm = _q_invperm;
cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups;
cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL);
groupsize = 1;
while (groupsize * groups < height) groupsize *= 2;
// Create group map
rows_8 = 0;
rows_6 = 0;
rows_5 = 0;
rows_4 = 0;
rows_3 = 0;
rows_2 = 0;
if (!is_gptq)
{
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
hipMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), hipMemcpyDeviceToHost);
for (int i = 0; i < groups; i++)
{
int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize;
if (bits == 5) rows_5 += groupsize;
if (bits == 4) rows_4 += groupsize;
if (bits == 3) rows_3 += groupsize;
if (bits == 2) rows_2 += groupsize;
}
free(cpu_q_groups);
rows_6 += rows_8;
rows_5 += rows_6;
rows_4 += rows_5;
rows_3 += rows_4;
rows_2 += rows_3;
}
else
{
rows_4 = height;
rows_3 = height;
rows_2 = height;
if (_gptq_g_idx)
{
if (!make_sequential(_gptq_g_idx))
{
failed = true;
//printf("FAIL\n");
return;
}
}
}
// Shuffle quantized data
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
hipLaunchKernelGGL(( shuffle_kernel), dim3(gridDim), dim3(blockDim), 0, 0, cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}
QMatrix::~QMatrix()
{
}
// Reconstruct b[k,n] (GPTQ)
__global__ void reconstruct_gptq_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_4
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ uint16_t perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
// Reconstruct b[k,n]
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
// Preload remapping table
int t = threadIdx.x;
__shared__ uint16_t perm[BLOCK_KN_SIZE];
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
// Column
int n = offset_n + t;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
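// Rows are stored highest bit-width first (8, 6, 5, 4, 3, 2 bits). Count how many source rows of each
// width precede offset_k; every 32 source rows occupy `bits` packed uint32 rows, giving the packed row
// offset qk into b_q_weight.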
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k;
int lk = 0;
__syncthreads();
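// Walk down this column, switching dequantizer whenever k crosses a bit-width boundary (rows_8, rows_6, ...).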
while (k < rows_8 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
dequant_8bit_8(q_0, q_1, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
uint32_t q_3 = *b_ptr; b_ptr += size_n;
uint32_t q_4 = *b_ptr; b_ptr += size_n;
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_4bit_8(q_0, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_2bit_16(q_0, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
}
void QMatrix::reconstruct(half* out)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
if (!is_gptq)
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
hipLaunchKernelGGL(( reconstruct_kernel), dim3(gridDim), dim3(blockDim), 0, 0,
cuda_q_weight,
cuda_q_perm,
cuda_q_scale,
cuda_q_scale_max,
//cuda_q_groups,
height,
width,
groupsize,
groups,
out,
rows_8,
rows_6,
rows_5,
rows_4,
rows_3,
rows_2
);
}
else
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
hipLaunchKernelGGL(( reconstruct_gptq_kernel), dim3(gridDim), dim3(blockDim), 0, 0,
cuda_q_weight,
cuda_q_perm,
cuda_gptq_qzeros,
cuda_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
height,
width,
groupsize,
groups,
out,
rows_4
);
}
}
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint16_t* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
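// Operate on 64-bit words so each thread moves two adjacent 32-bit columns at once; each output word
// gathers eight 4-bit values per column from eight (possibly scattered) source rows.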
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
hipError_t err = hipMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
if (err != hipSuccess) {
hipError_t cuda_status = hipGetLastError(); // Clear error
return false;
}
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
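// cpu_g_idx_map now holds the exclusive prefix sum of the histogram, i.e. the first free destination row
// for each group.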
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
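// cpu_x_map_inv maps a source row to its group-sorted destination row; cpu_x_map is the inverse
// (destination -> source) and is what make_sequential_kernel reads as q_perm.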
// Reduce to uint16_t
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
// Move to CUDA
hipMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), hipMemcpyHostToDevice);
hipMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), hipMemcpyHostToDevice);
// Rearrange rows in w
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
hipLaunchKernelGGL(( make_sequential_kernel), dim3(gridDim), dim3(blockDim), 0, 0,
cuda_q_weight,
cuda_new_qweight,
cuda_q_perm,
height / 8,
width
);
// Replace qweights
hipMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), hipMemcpyDeviceToDevice);
// Cleanup
hipDeviceSynchronize();
hipFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
return true;
}
autogptq_extension/exllamav2/hip/quant/qdq_2.cuh
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_2BIT == 1
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        uint32_t qa0 = qa & 0x03;
        uint32_t qa1 = (qa & 0x0c) >> 2;
        qa >>= 4;
        qb |= (qa1 << (i * 2 + 16));
        qb |= (qa0 << (i * 2));
    }
    q[0] = qb;
}
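// Note on the constants below: 0x6400 is the half-precision bit pattern of 1024.0, so OR-ing a small
// quantized field into the mantissa yields the exact value (q << p) + 1024 for a field at bit position p.
// The y*/z* constants then divide out 2^p and subtract 1024 plus the symmetric zero point (2 for 2-bit)
// with a single __hadd2 or __hfma2 per output pair.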
__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y4_  = __float2half_rn(1.0f /  4.0f);
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y4  = __halves2half2(y4_,  y4_);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half z1_  = __float2half_rn(-1024.0f         - 2.0f);
    const half z4_  = __float2half_rn(-1024.0f /  4.0f - 2.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z4  = __halves2half2(z4_,  z4_);
    const half2 z16 = __halves2half2(z16_, z16_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024
    half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
    half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
    qa >>= 8;
    half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9])      + 1024
    half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024
    half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
    half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y4,  z4);
    dq[2] = __hfma2(q2.as_half2, y16, z16);
    dq[3] = __hfma2(q3.as_half2, y64, z64);
    dq[4] = __hadd2(q4.as_half2, z1);
    dq[5] = __hfma2(q5.as_half2, y4,  z4);
    dq[6] = __hfma2(q6.as_half2, y16, z16);
    dq[7] = __hfma2(q7.as_half2, y64, z64);
}

#else

__forceinline__ __device__ void shuffle_2bit_16
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_2bit_16
(
    const uint32_t q_0,
    half2 (&dq)[8],
    int stride
)
{
    half dqh[16];
    for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);

    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
autogptq_extension/exllamav2/hip/quant/qdq_3.cuh
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_3BIT == 1
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];

    // qa: aa999888 77766655 54443332 22111000
    // qb: lkkkjjji iihhhggg fffeeedd dcccbbba
    // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll

    uint32_t qd = qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa: ..999888 77766655 54443332 22111000
    // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
    // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
    // qd:                              vvvuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;

    for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
    for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }

    // za: 9997775 55333111 8886664 44222000
    // zb: jjjhhhf ffdddbbb iiiggge eecccaaa
    // zc: tttrrrp ppnnnlll sssqqqo oommmkkk
    // qd:                            vvvuuu

    za |= ((qd & 0x01) >> 0) << 15;
    zb |= ((qd & 0x02) >> 1) << 15;
    zc |= ((qd & 0x04) >> 2) << 15;
    za |= ((qd & 0x08) >> 3) << 31;
    zb |= ((qd & 0x10) >> 4) << 31;
    zc |= ((qd & 0x20) >> 5) << 31;

    // za: v9997775 55333111 u8886664 44222000 (u, v lsb)
    // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
    // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y8_  = __float2half_rn(1.0f /  8.0f);
    const half y64_ = __float2half_rn(1.0f / 64.0f);
    const half2 y8  = __halves2half2(y8_,  y8_);
    const half2 y64 = __halves2half2(y64_, y64_);
    const half z1_  = __float2half_rn(-1024.0f         - 4.0f);
    const half z8_  = __float2half_rn(-1024.0f /  8.0f - 4.0f);
    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z8  = __halves2half2(z8_,  z8_);
    const half2 z64 = __halves2half2(z64_, z64_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;

    half2_uint32 q0 ((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1 ((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) *  8 + 1024
    qa >>= 6;
    half2_uint32 q2 ((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3 ((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) *  8 + 1024
    half2_uint32 q4 ((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
    qa >>= 9;
    qa &= 0x00010001;
    half2_uint32 q5 ((qb & 0x00070007) | c0); // half2(q[10], q[11])      + 1024
    half2_uint32 q6 ((qb & 0x00380038) | c0); // half2(q[12], q[13]) *  8 + 1024
    qb >>= 6;
    half2_uint32 q7 ((qb & 0x00070007) | c0); // half2(q[14], q[15])      + 1024
    half2_uint32 q8 ((qb & 0x00380038) | c0); // half2(q[16], q[17]) *  8 + 1024
    half2_uint32 q9 ((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
    qb >>= 8;
    qb &= 0x00020002;
    half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21])      + 1024
    half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) *  8 + 1024
    qc >>= 6;
    half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) *  8 + 1024
    half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
    qc >>= 7;
    qc &= 0x00040004;
    half2_uint32 q15((qa | qb | qc) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y8,  z8);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hfma2( q3.as_half2, y8,  z8);
    dq[ 4] = __hfma2( q4.as_half2, y64, z64);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hfma2( q6.as_half2, y8,  z8);
    dq[ 7] = __hadd2( q7.as_half2, z1);
    dq[ 8] = __hfma2( q8.as_half2, y8,  z8);
    dq[ 9] = __hfma2( q9.as_half2, y64, z64);
    dq[10] = __hadd2(q10.as_half2, z1);
    dq[11] = __hfma2(q11.as_half2, y8,  z8);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y8,  z8);
    dq[14] = __hfma2(q14.as_half2, y64, z64);
    dq[15] = __hadd2(q15.as_half2, z1);
}

#else

__forceinline__ __device__ void shuffle_3bit_32
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_3bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[16],
    int stride
)
{
    half dqh[32];
    for (int i = 0; i < 10; i++) dqh[     i] = dq_ns(exb(q_0,      i * 3    , 0x07), 4);
                                 dqh[    10] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
    for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb(q_1,      i * 3 + 1, 0x07), 4);
                                 dqh[    21] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
    for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb(q_2,      i * 3 + 2, 0x07), 4);

    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
autogptq_extension/exllamav2/hip/quant/qdq_4.cuh
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_4BIT == 1
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        uint32_t qa0 = qa & 0x0f;
        uint32_t qa1 = (qa & 0xf0) >> 4;
        qa >>= 8;
        qb |= (qa1 << (i * 4 + 16));
        qb |= (qa0 << (i * 4));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half z1_  = __float2half_rn(-1024.0f         - 8.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z16 = __halves2half2(z16_, z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    half2 scale2 = __half2half2(scale);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 )
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 )
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    if (scaled)
    {
        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
    }
    else
    {
        dq[0] = __hadd2(q0.as_half2,           z1z16[0]); // half2( q[0] - z, q[1] - z )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
        dq[2] = __hadd2(q2.as_half2,           z1z16[0]); // half2( q[4] - z, q[5] - z )
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
    }
}

#else

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z = __hmul(z, scale);
    z1[0] = __half2half2(z);
    y1[0] = __half2half2(scale);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(z);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    half2 dqh2[8];

    uint32_t qa = q_0;
    for (int i = 0; i < 4; i++)
    {
        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
        dqh2[i] = __halves2half2(d0, d1);
    }

    if (scaled)
    {
        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
    }
    else
    {
        dq[0] = __hadd2(dqh2[0], z1[0]);
        dq[1] = __hadd2(dqh2[1], z1[0]);
        dq[2] = __hadd2(dqh2[2], z1[0]);
        dq[3] = __hadd2(dqh2[3], z1[0]);
    }
}
#endif
#endif
autogptq_extension/exllamav2/hip/quant/qdq_5.cuh
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_5_cuh
#define _qdq_5_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_5BIT == 1
// Permutation:
//
// v5555533 33311111 u4444422 22200000 (u, v lsb)
// vbbbbb99 99977777 uaaaaa88 88866666
// vhhhhhff fffddddd ugggggee eeeccccc
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
// vtttttrr rrrppppp usssssqq qqqooooo
__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];
    uint32_t qd = q[3 * stride];
    uint32_t qe = q[4 * stride];

    // qa: 66555554 44443333 32222211 11100000
    // qb: ccccbbbb baaaaa99 99988888 77777666
    // qc: jiiiiihh hhhggggg fffffeee eedddddc
    // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
    // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp

    uint32_t qf = qe >> 22;
    qe <<= 8;
    qe |= qd >> 24;
    qd <<= 6;
    qd |= qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa:   555554 44443333 32222211 11100000
    // qb:   bbbbba aaaa9999 98888877 77766666
    // qc:   hhhhhg ggggffff feeeeedd dddccccc
    // qd:   nnnnnm mmmmllll lkkkkkjj jjjiiiii
    // qe:   ttttts ssssrrrr rqqqqqpp pppooooo
    // qf:                         vv vvvuuuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;
    uint32_t zd = 0;
    uint32_t ze = 0;

    for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }

    // za:  5555533 33311111  4444422 22200000
    // zb:  bbbbb99 99977777  aaaaa88 88866666
    // zc:  hhhhhff fffddddd  gggggee eeeccccc
    // zd:  nnnnnll llljjjjj  mmmmmkk kkkiiiii
    // ze:  tttttrr rrrppppp  sssssqq qqqooooo
    // qf:                         vv vvvuuuuu

    za |= ((qf & 0x001) >> 0) << 15;
    zb |= ((qf & 0x002) >> 1) << 15;
    zc |= ((qf & 0x004) >> 2) << 15;
    zd |= ((qf & 0x008) >> 3) << 15;
    ze |= ((qf & 0x010) >> 4) << 15;
    za |= ((qf & 0x020) >> 5) << 31;
    zb |= ((qf & 0x040) >> 6) << 31;
    zc |= ((qf & 0x080) >> 7) << 31;
    zd |= ((qf & 0x100) >> 8) << 31;
    ze |= ((qf & 0x200) >> 9) << 31;

    // za: v5555533 33311111 u4444422 22200000 (u, v lsb)
    // zb: vbbbbb99 99977777 uaaaaa88 88866666
    // zc: vhhhhhff fffddddd ugggggee eeeccccc
    // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
    // ze: vtttttrr rrrppppp usssssqq qqqooooo

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
    q[3 * stride] = zd;
    q[4 * stride] = ze;
}

__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y32_ = __float2half_rn(1.0f / 32.0f);
    const half2 y32 = __halves2half2(y32_, y32_);
    const half z1_  = __float2half_rn(-1024.0f         - 16.0f);
    const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z32 = __halves2half2(z32_, z32_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;
    uint32_t qd = q_3;
    uint32_t qe = q_4;

    half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
    qa >>= 10;
    half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5])      + 1024
    qa >>= 5;
    qa &= 0x00010001;
    half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7])      + 1024
    half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
    qb >>= 10;
    half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11])      + 1024
    qb >>= 4;
    qb &= 0x00020002;
    half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13])      + 1024
    half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
    qc >>= 10;
    half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17])      + 1024
    qc >>= 3;
    qc &= 0x00040004;
    half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19])      + 1024
    half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
    qd >>= 10;
    half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23])      + 1024
    qd >>= 2;
    qd &= 0x00080008;
    half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
    qe >>= 10;
    half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29])      + 1024
    qe >>= 1;
    qe &= 0x00100010;
    half2_uint32 q15((qa | qb | qc | qd | qe) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y32, z32);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hadd2( q3.as_half2, z1);
    dq[ 4] = __hfma2( q4.as_half2, y32, z32);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hadd2( q6.as_half2, z1);
    dq[ 7] = __hfma2( q7.as_half2, y32, z32);
    dq[ 8] = __hadd2( q8.as_half2, z1);
    dq[ 9] = __hadd2( q9.as_half2, z1);
    dq[10] = __hfma2(q10.as_half2, y32, z32);
    dq[11] = __hadd2(q11.as_half2, z1);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y32, z32);
    dq[14] = __hadd2(q14.as_half2, z1);
    dq[15] = __hadd2(q15.as_half2, z1);
}

#else

__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    half dqh[32];
    for (int i = 0; i < 6; i++) dqh[     i] = dq_ns(exb(q_0,      i * 5    , 0x1f), 16);
                                dqh[     6] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
    for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb(q_1,      i * 5 + 3, 0x1f), 16);
                                dqh[    12] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
    for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb(q_2,      i * 5 + 1, 0x1f), 16);
                                dqh[    19] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
    for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb(q_3,      i * 5 + 4, 0x1f), 16);
                                dqh[    25] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
    for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb(q_4,      i * 5 + 2, 0x1f), 16);

    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
autogptq_extension/exllamav2/hip/quant/qdq_6.cuh
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_6_cuh
#define _qdq_6_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_6BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_6bit_16
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_6bit_16
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[8],
    int stride
)
{
    half dqh[16];
    for (int i = 0; i < 5; i++) dqh[     i] = dq_ns(exb(q_0,      i * 6    , 0x3f), 32);
                                dqh[     5] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
    for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb(q_1,      i * 6 + 4, 0x3f), 32);
                                dqh[    10] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
    for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb(q_2,      i * 6 + 2, 0x3f), 32);

    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
autogptq_extension/exllamav2/hip/quant/qdq_8.cuh
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_8BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_8bit_4
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_8bit_8
(
    const uint32_t q_0,
    const uint32_t q_1,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 4; i++) dqh[    i] = dq_ns(exb(q_0, i * 8, 0xff), 128);
    for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
autogptq_extension/exllamav2/hip/quant/qdq_util.cuh
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
union half2_uint32
{
    uint32_t as_uint32;
    half2 as_half2;
    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16
{
    uint16_t as_uint16;
    half as_half;
    __device__ half_uint16(uint16_t val) : as_uint16(val) {}
    __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
    int qs_i = qs + 1;
    half qs_h = __int2half_rn(qs_i * qs_i);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}
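// Worked example: max_scale is stored premultiplied by 1/256 and qs is a 4-bit index, so the
// reconstructed scale is ((qs + 1)^2 / 256) * original max_scale: qs = 15 recovers the full scale
// (16^2 / 256 = 1), while qs = 0 gives 1/256 of it.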
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
    return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
    //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
    return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
    return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
    return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
#endif