OpenDAS / bitsandbytes

Commit 5f492d43
Authored Jul 10, 2023 by Tim Dettmers
Merge remote-tracking branch 'origin/inference'
Parents: 196d6f5d 5fab6734

Showing 11 changed files with 522 additions and 306 deletions (+522 -306)
Makefile                               +2    -2
bitsandbytes/autograd/_functions.py    +6    -3
bitsandbytes/functional.py             +109  -112
bitsandbytes/nn/modules.py             +3    -3
csrc/kernels.cu                        +216  -28
csrc/kernels.cuh                       +2    -0
csrc/ops.cu                            +29   -7
csrc/ops.cuh                           +1    -0
csrc/pythonInterface.c                 +49   -11
tests/test_functional.py               +104  -140
tests/test_modules.py                  +1    -0
Makefile

@@ -47,8 +47,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80

 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
-CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
-CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+# CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+# CC_cublasLt111 += -gencode arch=compute_86,code=sm_86

 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
bitsandbytes/autograd/_functions.py

@@ -512,7 +512,7 @@ class MatMul4Bit(torch.autograd.Function):
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)

         # 3. Save state
         ctx.state = state

@@ -543,7 +543,7 @@ class MatMul4Bit(torch.autograd.Function):
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

         return grad_A, grad_B, None, grad_bias, None

@@ -564,4 +564,7 @@ def matmul(

 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
-    return MatMul4Bit.apply(A, B, out, bias, quant_state)
+    if A.numel() == A.shape[-1] and A.requires_grad == False:
+        return F.gemv_4bit(A, B.t(), out, state=quant_state)
+    else:
+        return MatMul4Bit.apply(A, B, out, bias, quant_state)
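The new branch in matmul_4bit only takes the fused inference path for a single-token activation that does not require gradients; everything else still goes through MatMul4Bit. A minimal sketch of that dispatch rule (illustrative only; the actual bitsandbytes calls are replaced by strings so it runs without the CUDA extension):

import torch

def dispatch_matmul_4bit(A: torch.Tensor) -> str:
    # During token-by-token generation A has shape [1, 1, hidden] (or [1, hidden]),
    # so its element count equals its last dimension.
    if A.numel() == A.shape[-1] and A.requires_grad == False:
        return "F.gemv_4bit"        # fused dequantize + matrix-vector kernel
    return "MatMul4Bit.apply"       # dequantize_4bit + torch.nn.functional.linear

print(dispatch_matmul_4bit(torch.empty(1, 1, 4096)))  # F.gemv_4bit
print(dispatch_matmul_4bit(torch.empty(1, 7, 4096)))  # MatMul4Bit.apply (7-token prefill)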
bitsandbytes/functional.py

@@ -240,17 +240,19 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
         v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
         v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
+        v = v1 + v2 + v3
     else:
         v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
         v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
         v = v1 + v2 + v3

     values = torch.Tensor(v)
     values = values.sort().values
     values /= values.max()
     assert values.numel() == 256
     return values


 def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):

@@ -617,6 +619,8 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
             lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
         elif A.dtype == torch.float16:
             lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
+        elif A.dtype == torch.bfloat16:
+            lib.cquantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         post_call(A.device)

@@ -629,11 +633,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ou
         offset = absmax.mean()
         absmax -= offset
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
-        state = [qabsmax, code, blocksize, nested, offset, state2]
+        state = [qabsmax, code, blocksize, nested, A.dtype, offset, state2]
     else:
-        state = [absmax, code, blocksize, nested, None, None]
+        state = [absmax, code, blocksize, nested, A.dtype, None, None]

     return out, state
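As a quick sketch of what the extra state entry buys: quantize_blockwise now records the input dtype, so dequantize_blockwise (next hunk) can allocate its output in that dtype instead of defaulting to float32. Assumes a CUDA build of bitsandbytes that includes this change:

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, device='cuda', dtype=torch.float16)
q, state = F.quantize_blockwise(A)          # state now carries A.dtype
A_hat = F.dequantize_blockwise(q, state)
print(A_hat.dtype)                          # torch.float16, not float32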
@@ -678,18 +680,16 @@ def dequantize_blockwise(
         name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]

-    if out is None:
-        out = torch.zeros_like(A, dtype=torch.float32)
-
     if quant_state is None:
-        quant_state = (absmax, code, blocksize)
+        quant_state = (absmax, code, blocksize, False, torch.float32, None, None)
         assert absmax is not None and out is not None
     else:
-        absmax, code, blocksize, nested, offset, state2 = quant_state
+        absmax, code, blocksize, nested, dtype, offset, state2 = quant_state
         if nested:
             absmax = dequantize_blockwise(absmax, state2)
             absmax += offset

+    if out is None:
+        out = torch.empty(A.shape, dtype=dtype, device=A.device)
+
     if A.device.type != 'cpu':
         device = pre_call(A.device)

@@ -701,6 +701,8 @@ def dequantize_blockwise(
             lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
             lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+        elif out.dtype == torch.bfloat16:
+            lib.cdequantize_blockwise_bf16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         post_call(A.device)

@@ -710,6 +712,47 @@ def dequantize_blockwise(
     return out

+
+def get_4bit_type(typename, device=None, blocksize=64):
+    if device is None: device = 'cuda'
+    data = None
+    if typename == 'nf4':
+        data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0]
+    elif typename == 'fp4':
+        # 0b000 = 0
+        # 0b001 = 0.0625
+        # 0b010 = 8
+        # 0b011 = 12
+        # 0b100 = 4
+        # 0b101 = 6
+        # 0b110 = 2
+        # 0b111 = 3
+        data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
+    elif typename == 'int4':
+        data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
+    elif typename == 'af4':
+        # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
+        # https://arxiv.org/abs/2306.06965
+        if blocksize == 64:
+            data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478, -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666, 0.42563882, 0.55496234, 0.72424863, 1.][::-1]
+        else:
+            raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.')
+
+    if data is None:
+        raise NotImplementedError(f'Typename {typename} not supported')
+
+    data = Tensor(data)
+    data /= data.abs().max()
+    assert data.numel() == 16
+
+    return data.to(device)
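A small check of the codebooks get_4bit_type returns; the values are the ones listed above, rescaled so the largest magnitude is 1.0 (illustrative only, runs on CPU):

import bitsandbytes.functional as F

nf4 = F.get_4bit_type('nf4', device='cpu')
fp4 = F.get_4bit_type('fp4', device='cpu')
print(nf4.numel(), fp4.numel())            # 16 16
print(float(nf4.min()), float(nf4.max()))  # -1.0 1.0
print(float(fp4.abs().max()))              # 1.0 (12.0 before rescaling)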
 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')

@@ -774,20 +817,25 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
             lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
         else:
             lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+    elif A.dtype == torch.bfloat16:
+        if quant_type == 'fp4':
+            lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+        else:
+            lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

+    datatype = get_4bit_type(quant_type, device=A.device)
+
     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
         #code = create_custom_map().to(absmax.device)
         #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
+        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
+        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]

     return out, state

@@ -834,7 +882,7 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state

     if compressed_stats is not None:

@@ -860,6 +908,11 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
             lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
         else:
             lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+    elif out.dtype == torch.bfloat16:
+        if quant_type == 'fp4':
+            lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+        else:
+            lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
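A sketch of a 4-bit round trip with the extended quant_state: the codebook tensor is now appended as the last element, so the state has seven entries. Assumes a CUDA build of bitsandbytes that includes this change:

import torch
import bitsandbytes.functional as F

W = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
W4, state = F.quantize_4bit(W, blocksize=64, quant_type='nf4')
absmax, shape, dtype, blocksize, compressed_stats, quant_type, datatype = state
print(quant_type, datatype.numel())        # nf4 16
W_hat = F.dequantize_4bit(W4, state)
print(float((W - W_hat).abs().mean()))     # small quantization error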
@@ -1398,7 +1451,7 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
     return sout


-def cutlass3_gemm(
+def gemv_4bit(
     A: Tensor,
     B: Tensor,
     out: Tensor = None,

@@ -1406,95 +1459,35 @@ def cutlass3_gemm(
     transposed_B=False,
     state=None
 ):
+    prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
-        Bshape = B.shape
-        bout = Bshape[1]
-    else:
-        Bshape = state[1]
-        bout = Bshape[0]
-    if out is None:
-        out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
-
-    sA = A.shape
-    sB = B.shape
-    if transposed_A and len(sA) == 2: sA = (sA[1], sA[0])
-    elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0])
-    if transposed_B and len(sB) == 2: sB = (sB[1], sB[0])
-    elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0])
-    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
-    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
-    # (transpose of row major is column major)
-    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
-    # matrices in the input arguments for cuBLAS
-    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
-    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
-    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
-    if len(sB) == 2:
-        if B.stride()[0] == B.shape[1]:
-            transposed_B = False
-        elif B.stride()[1] == B.shape[0]:
-            transposed_B = True
-        if len(A.shape) == 2:
-            if A.stride()[0] == A.shape[1]:
-                transposed_A = False
-            elif A.stride()[1] == A.shape[0]:
-                transposed_A = True
-        else:
-            if A.stride()[1] == A.shape[2]:
-                transposed_A = False
-            elif A.stride()[2] == A.shape[1]:
-                transposed_A = True
-        if len(sA) == 2:
-            n = sA[0]
-            ldb = A.stride()[1 if transposed_A else 0]
-        elif len(sA) == 3 and len(sB) == 2:
-            n = sA[0] * sA[1]
-            ldb = sA[2]
-        m = sB[1]
-        k = sB[0]
-        lda = B.stride()[0]
-        ldc = sB[1]
-    elif len(sB) == 3:
-        # special case
-        assert len(sA) == 3
-        if not (sA[0] == sB[0] and sA[1] == sB[1]):
-            raise ValueError(f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}")
-        transposed_A = True
-        transposed_B = False
-        m = sB[2]
-        n = sA[2]
-        k = sB[0] * sB[1]
-        lda = n
-        ldb = sA[2]
-        ldc = m
-    ptr = CUBLAS_Context.get_instance().get_context(A.device)
-    # B^T @ A^T = C^T
-    # [km, nk -> mn]
-    #lda = ldb = ldc = 1
-    #lda = 1
-    if state is not None:
-        m = Bshape[0]
-        k = Bshape[1]
-        lda = Bshape[0]
-        ldc = Bshape[0]
-        ldb = (ldb+1)//2
-    #print(m, n, k, lda, ldb, ldc)
-    is_on_gpu([B, A, out])
+        raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')
+
+    if A.numel() != A.shape[-1]:
+        raise ValueError(f'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]')
+
+    Bshape = state[1]
+    bout = Bshape[0]
+    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
+    if compressed_stats is not None:
+        offset, state2 = compressed_stats
+        absmax = dequantize_blockwise(absmax, state2)
+        absmax += offset
+
+    if out is None:
+        if len(A.shape) == 3:
+            out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
+        else:
+            out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
+
+    n = 1
+    m = Bshape[0]
+    k = Bshape[1]
+    lda = Bshape[0]
+    ldc = Bshape[0]
+    ldb = (A.shape[-1]+1)//2
+    is_on_gpu([B, A, out, absmax, state[-1]])
     m = ct.c_int32(m)
     n = ct.c_int32(n)
     k = ct.c_int32(k)

@@ -1503,16 +1496,20 @@ def cutlass3_gemm(
     ldc = ct.c_int32(ldc)

     if B.dtype == torch.uint8:
-        lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
-    elif A.dtype == torch.float32:
-        lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
-    elif A.dtype == torch.float16:
-        lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
+        if A.dtype == torch.float16:
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+        elif A.dtype == torch.bfloat16:
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+        elif A.dtype == torch.float32:
+            lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+        else:
+            raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

-    return out
+    post_call(prev_device)
+
+    return out
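gemv_4bit (the renamed cutlass3_gemm above) only accepts an activation that is a single vector (A.numel() == A.shape[-1]) and requires the quant_state produced by quantize_4bit; the packed weight is passed transposed, as matmul_4bit does. A hedged usage sketch, assuming a CUDA build with this change:

import torch
import bitsandbytes.functional as F

W = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
W4, state = F.quantize_4bit(W, quant_type='nf4')                  # packed uint8 weight
x = torch.randn(1, 1, 4096, device='cuda', dtype=torch.float16)   # one token
y = F.gemv_4bit(x, W4.t(), state=state)                           # fused dequant + GEMV
print(y.shape)                                                    # torch.Size([1, 1, 4096])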
bitsandbytes/nn/modules.py

@@ -190,9 +190,9 @@ class Params4bit(torch.nn.Parameter):
                 #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax

                 # for 8-bit
-                s[-2][0] = s[-2][0].to(device) # offset
-                s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics
-                s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook
+                s[-3][0] = s[-3][0].to(device) # offset
+                s[-3][1][0] = s[-3][1][0].to(device) # nested quantiation state statitics
+                s[-3][1][1] = s[-3][1][1].to(device) # nested quantiation codebook
         new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                                requires_grad=self.requires_grad, quant_state=self.quant_state,
                                blocksize=self.blocksize, compress_statistics=self.compress_statistics,
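The index shift above (s[-2] becoming s[-3]) follows directly from the longer quant_state: with the codebook appended at the end, the compressed-statistics entry is now third from the end. An illustrative layout mirroring the lists built in functional.py:

quant_state = [
    'qabsmax',             # 0  quantized absmax (compress_statistics=True)
    'input_shape',         # 1
    'dtype',               # 2
    'blocksize',           # 3
    ['offset', 'state2'],  # 4 == -3  nested statistics that .to(device) must move
    'quant_type',          # 5 == -2
    'datatype',            # 6 == -1  newly appended codebook tensor
]
assert quant_state[-3] == ['offset', 'state2']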
csrc/kernels.cu

@@ -3088,7 +3088,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
   }
 }

-#define WARPS 5
+#define WARPS 3
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {

@@ -3297,33 +3297,58 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
 #endif
 }

+template <typename T> __device__ void printnonzero(T *A, int num_values, const char*strval)
+{
+  for(int i = 0; i < num_values; i++)
+    if((float)A[i] != 0.0)
+      printf("%s %i %f\n", strval, i, (float)A[i]);
+}
+
+template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
+template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
+
+__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
+
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
+#if __CUDA_ARCH__ >= 750
   using namespace nvcuda;
   int col_offset = blockIdx.x *32;
   const int warp_id = threadIdx.x / 32;
+  const int warp_idx = threadIdx.x % 32;
   const int half_warp_id = threadIdx.x / 16;
   const int half_warp_lane = threadIdx.x % 16;
   const int batch_size_warps = (WARPS-1)*2;

+  T quant_map[16];
+
+  #pragma unroll 16
+  for(int i = 0; i < 16; i++)
+    quant_map[i] = nf4_data[i];
+  //__shared__ T quant_map[16*160];
+
   T local_A[2];
   T local_B[64];
   unsigned char local_B_4bit[32];

   const int a_tile_offset = 16;
   const int b_tile_offset = (16*32 + 16);

-  __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
+  __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))];
   __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
-  //__shared__ T smem_C[8*32];
+  __shared__ T smem_C[8*32];

   wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
   wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
   wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
   wmma::fill_fragment(c_frag, 0.0f);

+  for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
+    smem_C[i] = 0.0f;
+
+  __syncthreads();
+
   int ticktock = 0;
   int idx = 0 + threadIdx.x;
   int loaded_values = 0;

@@ -3349,8 +3374,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
       #pragma unroll 64
       for(int col = 0; col < 64; col+=2)
       {
-        local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
-        local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+        //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
+        //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+        //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
+        //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
+        //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
+        //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
+        //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
+        //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
+        local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0);
+        local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0);
       }
     }

@@ -3374,13 +3408,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
           smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
     }
     ticktock = ticktock == 0 ? 1 : 0;

+  //if(threadIdx.x == 0)
+    //printf("aa %i %i\n", idx, loaded_values);
+
   //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   {
     idx = base_idx + threadIdx.x;

-    __syncthreads();
+    //if(threadIdx.x == 0)
+    //printf("%i %i\n", idx, loaded_values);
+    //__syncthreads();
     if(idx < K && warp_id < (WARPS-1))
     {
       if(loaded_values == 0)

@@ -3408,9 +3446,17 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
         #pragma unroll 64
         for(int col = 0; col < 64; col+=2)
         {
-          local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
-          local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+          //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
+          //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+          //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
+          //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
+          //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
+          //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
+          local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
+          local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
         }
+        //printnonzero<T>(local_B, 128, "");
       }

       smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

@@ -3444,6 +3490,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   }

   __syncthreads();
+  //if(threadIdx.x == 0)
+  //{
+  //  printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
+  //  printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
+  //}
   if(warp_id != (WARPS-1)){ return; }
   // only warp_id == (WARPS-1) from here
   int warp_lane = threadIdx.x % 32;

@@ -3451,6 +3502,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   ticktock = ticktock == 0 ? 1 : 0;
   for(int k = 0; k < batch_size_warps; k++)
   {
+    //if(warp_lane == 0)
+    //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
     wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
     wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
     wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

@@ -3458,13 +3511,116 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   // 129 mu
   if(warp_id == (WARPS-1))
-    wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
+    wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major);
+
+  //printnonzero<T>(smem_C, 32, "");

   if(col_offset + warp_lane < M)
-    out[col_offset + warp_lane] = smem_A[warp_lane];
+    out[col_offset + warp_lane] = smem_C[warp_lane];
+#endif
 }

+#define num_values_4bit 32
+template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+
+  // per threadblock:
+  // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
+  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
+  // 4 warps -> 4 loads per iter
+  // 1x128 * 128x4 -> 1x4 outputs
+  typedef cub::WarpReduce<float> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
+
+  const int warp_idx = threadIdx.x / 32;
+  const int warp_lane = threadIdx.x % 32;
+  const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
+  const int num_values_8bit = num_values_4bit/2;
+  float local_C = 0.0f;
+
+  unsigned char local_B_4bit[num_values_8bit];
+  T local_B[num_values_4bit];
+  T local_A[num_values_4bit];
+  __shared__ T quant_map[16];
+  T local_absmax = T(0.0f);
+
+  for(int i = threadIdx.x; i < 16; i++)
+    quant_map[i] = datatype[i];
+
+  __syncthreads();
+
+  // A: [1, K]
+  // B: [N, K]
+  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
+  {
+    int inner_idx_halved = inner_idx/2;
+    int offset_B = ldb*row_B;
+    int absidx = ((2*offset_B)+inner_idx)/blocksize;
+    local_absmax = __ldg(&(absmax[absidx]));
+
+    if(row_B < M)
+    {
+      if((inner_idx_halved + num_values_8bit) < K)
+      {
+        reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
+      }
+      else
+      {
+        #pragma unroll
+        for(int j = 0; j < (num_values_8bit); j++)
+          if((inner_idx_halved) + j < K)
+            local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
+          else
+            local_B_4bit[j] = 0b01110111;
+      }
+    }
+
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit; k++)
+    {
+      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
+      local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+    }
+
+    if(inner_idx+num_values_4bit)
+    {
+      if(BITS==16)
+      {
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 1];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 2];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 3];
+      }
+      else
+      {
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 0];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 1];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 2];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 3];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[4] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 4];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[5] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 5];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[6] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 6];
+        reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[7] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + 7];
+      }
+    }
+    else
+      for(int k = 0; k < num_values_4bit; k++)
+        local_A[k] = A[inner_idx + k];
+
+    #pragma unroll
+    for(int k = 0; k < num_values_4bit; k++)
+      local_C += (float)(local_A[k]*local_B[k]);
+  }
+
+  local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
+
+  if(row_B < M && warp_lane == 0)
+    out[row_B] = T(local_C);
+}
+
 //#define ROWS 2
 //template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
 //{
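For reference, a plain PyTorch version of what kgemm_4bit_inference_naive computes per output row: unpack two 4-bit indices from every byte of B, look them up in the 16-entry codebook, scale by the per-block absmax, and take the dot product with the vector A. This is an illustrative re-implementation (the real kernel uses one warp per row and a cub WarpReduce), not the kernel itself:

import torch

def gemv_4bit_reference(A, B_packed, absmax, codebook, blocksize):
    # A: [K], B_packed: [M, K//2] uint8, absmax: [M*K//blocksize], codebook: [16]
    M, K_half = B_packed.shape
    hi = (B_packed >> 4).long()                 # first 4-bit value of each byte
    lo = (B_packed & 0x0F).long()               # second 4-bit value
    idx = torch.stack([hi, lo], dim=-1).reshape(M, 2 * K_half)
    B = codebook[idx]                           # dequantized codebook values
    scales = absmax.reshape(M, -1).repeat_interleave(blocksize, dim=1)[:, :2 * K_half]
    return (B * scales) @ A.to(B.dtype)         # out: [M]

torch.manual_seed(0)
M, K, bs = 4, 64, 32
out = gemv_4bit_reference(torch.randn(K),
                          torch.randint(0, 256, (M, K // 2), dtype=torch.uint8),
                          torch.rand(M * K // bs),
                          torch.linspace(-1, 1, 16), bs)
print(out.shape)                                # torch.Size([4])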
@@ -3627,8 +3783,14 @@ template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * _
 template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half *B, half * out, int lda, int ldb, int ldc);
 template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half *B, half * out, int lda, int ldb, int ldc);

+template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+
+template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
+
 template __global__ void kExtractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

@@ -3784,6 +3946,20 @@ MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)

@@ -3792,13 +3968,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
-MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
-MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)

@@ -3806,13 +3975,6 @@ MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
-MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
-MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
-MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)

@@ -3821,12 +3983,38 @@ MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
+MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)

 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
-template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
-template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
+template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
+template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
+template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
+template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
+template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);

 #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
 template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
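With the __nv_bfloat16 instantiations above, the blockwise and 4-bit quantization kernels can be driven directly from bfloat16 tensors. A minimal sketch, assuming a CUDA build that includes these kernels and the matching Python bindings from this commit:

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, device='cuda', dtype=torch.bfloat16)
q, state = F.quantize_blockwise(A)
A_hat = F.dequantize_blockwise(q, state)
print(A_hat.dtype)                         # torch.bfloat16, taken from the state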
csrc/kernels.cuh

@@ -106,6 +106,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
 template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);

 __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);

@@ -124,6 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
csrc/ops.cu

@@ -723,10 +723,20 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
   //cout << m << endl;
   //cout << n << endl;
   //cout << k << endl;

-  kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference<T, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }

+template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
+{
+
+  int num_blocks = (m+3)/4;
+
+  kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+}
+
 template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
 {
   int threads = 512;
...
@@ -747,6 +757,10 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
...
@@ -773,19 +787,27 @@ template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n);
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
...
csrc/ops.cuh View file @ 5f492d43
...
@@ -200,6 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
...
csrc/pythonInterface.c View file @ 5f492d43
...
@@ -28,6 +28,15 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive<half, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive<float, 32>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
...
@@ -103,19 +112,29 @@ void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); }
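The new bf16 entry points above extend blockwise (de)quantization to bfloat16 tensors. At the Python level the round trip looks roughly like the following sketch, using the F.quantize_blockwise / F.dequantize_blockwise calls exercised in the tests further down (blocksize and printed error are illustrative, not library guarantees):

import torch
import bitsandbytes.functional as F

# Blockwise 8-bit quantization round trip; the output keeps the input dtype.
A1 = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16)
C, S = F.quantize_blockwise(A1, blocksize=256)   # uint8 codes + per-block absmax state
A2 = F.dequantize_blockwise(C, S)
assert A2.dtype == A1.dtype
err = (A1 - A2).abs().float().mean()             # compare in fp32, as the updated tests do
print(f"mean abs error: {err.item():.4f}")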
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
...
@@ -174,21 +193,31 @@ extern "C"
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); }
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
  float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
...
@@ -368,6 +397,15 @@ extern "C"
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
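These extern "C" symbols are what the Python-side single-row 4-bit inference path dispatches to. A minimal sketch using the F.gemv_4bit call exercised in the tests below (argument layout copied from those tests, everything else illustrative):

import math
import torch
import bitsandbytes.functional as F

dim = 4096
A = torch.randn(1, dim, dtype=torch.float16, device='cuda')                     # single-token activation
B = torch.randn(4 * dim, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

qB, state = F.quantize_4bit(B, quant_type='nf4')   # packed 4-bit weights + absmax/datatype state
C = F.gemv_4bit(A, qB.t(), state=state)            # fused dequantize + GEMV
C_ref = torch.matmul(A, B.t())                     # fp16 reference
print((C - C_ref).abs().float().mean().item())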
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
...
tests/test_functional.py View file @ 5f492d43
...
@@ -154,34 +154,36 @@ def test_dynamic_quantization():
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
def test_dynamic_blockwise_quantization(nested, blocksize):
def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
    #print('')
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1 + 1e-8)
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
    abserr = sum(diffs)/len(diffs)
    relerr = sum(reldiffs)/len(reldiffs)
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
    #print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
    assert abserr < 0.011
    assert relerr < 0.018
    #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
    assert A2.dtype == dtype
    #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        diff = torch.abs(A1 - A2).float()
        reldiff = diff / torch.abs(A1 + 1e-8)
        reldiff = diff / torch.abs(A1.float() + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
...
@@ -189,6 +191,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.0035
    assert relerr < 0.015
    assert A2.dtype == dtype
    #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
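The `nested` flag exercised here enables double quantization, where the per-block statistics are themselves quantized blockwise, shrinking the quantization state at a small cost in accuracy. A minimal sketch of the call (same API as in this test; the flag name is taken from it):

import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
# nested=True additionally quantizes the per-block absmax values ("double quantization").
C, S = F.quantize_blockwise(A, blocksize=64, nested=True)
A_hat = F.dequantize_blockwise(C, S)
print((A - A_hat).abs().float().mean().item())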
...
@@ -1781,16 +1784,16 @@ values = []
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 4096, 4*4096))
values.append((batch_size, seqdim, 5120, 4*5120))
#values.append((batch_size, seqdim, 5120, 4*5120))
values.append((batch_size, seqdim, 6656, 4*6656))
values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 8192, 4*8192))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    iters = 80
    iters = 1000
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
...
@@ -1800,7 +1803,8 @@ def test_bench_matmul(batch, seq, model, hidden):
    B_fp4, state = F.quantize_fp4(B)
    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
    B_nf4, state_nf4 = F.quantize_nf4(B)
    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
    linear8bit.eval()
...
@@ -1813,6 +1817,7 @@ def test_bench_matmul(batch, seq, model, hidden):
    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    # warmup
    for i in range(iters):
...
@@ -1827,26 +1832,34 @@ def test_bench_matmul(batch, seq, model, hidden):
    torch.cuda.synchronize()
    print(f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    torch.cuda.synchronize()
    #torch.cuda.synchronize()
    t0 = time.time()
    #t0 = time.time()
    for i in range(iters):
    #for i in range(iters):
        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
    torch.cuda.synchronize()
    #torch.cuda.synchronize()
    print(f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    #print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
    #torch.cuda.synchronize()
    #t0 = time.time()
    #for i in range(iters):
    #    bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
    #torch.cuda.synchronize()
    #print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
    torch.cuda.synchronize()
    print(f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
    torch.cuda.synchronize()
    print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    #torch.cuda.synchronize()
    #t0 = time.time()
...
@@ -1901,21 +1914,21 @@ def test_bench_matmul(batch, seq, model, hidden):
    #torch.cuda.synchronize()
    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    linear8bit(A)
    #linear8bit(A)
    torch.cuda.synchronize()
    #torch.cuda.synchronize()
    t0 = time.time()
    #t0 = time.time()
    for i in range(iters):
    #for i in range(iters):
        linear8bit(A)
    #    linear8bit(A)
    torch.cuda.synchronize()
    #torch.cuda.synchronize()
    print(f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    linearMixedBit(A)
    #linearMixedBit(A)
    torch.cuda.synchronize()
    #torch.cuda.synchronize()
    t0 = time.time()
    #t0 = time.time()
    for i in range(iters):
    #for i in range(iters):
        linearMixedBit(A)
    #    linearMixedBit(A)
    torch.cuda.synchronize()
    #torch.cuda.synchronize()
    print(f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
    #linear8bit_train(A)
    #torch.cuda.synchronize()
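The benchmark follows the usual CUDA timing pattern that the added warmup call makes explicit: run the kernel once to pay one-off setup costs, then synchronize around the timed loop so queued kernels are actually finished when the clock is read. A stripped-down sketch of that pattern (functions as used in this test, sizes and iteration count illustrative):

import time
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(1, 1, 4096, device='cuda').half()
B = torch.randn(4 * 4096, 4096, device='cuda').half()
B_nf4, state_nf4 = F.quantize_nf4(B)

bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)   # warmup: first call triggers lazy init
torch.cuda.synchronize()
t0 = time.time()
for _ in range(100):
    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
torch.cuda.synchronize()                                # wait for all queued kernels
print(f"bnb nf4: {time.time() - t0:.4f}s")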
...
@@ -2221,7 +2234,8 @@ def test_bench_dequantization():
def test_fp4_quant():
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_fp4_quant(dtype):
    vals = list(product([0, 1], repeat=4))
    code = {}
...
@@ -2243,7 +2257,7 @@ def test_fp4_quant():
        result = sign * exp * frac
        code[idx] = result
    A1 = torch.randn(1024, 1024, device='cuda').half()
    A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)
...
@@ -2252,7 +2266,7 @@ def test_fp4_quant():
    idx = err > 1.0
    err = err.mean()
    assert A2.dtype == dtype
    assert err.item() < 0.1
    assert relerr.item() < 0.28
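For reference, the FP4 path this test now parametrizes over dtypes is used the same way as the other blockwise calls; a minimal round-trip sketch with the functions shown above (blocksize illustrative, dequantized output comes back in the input dtype):

import torch
import bitsandbytes.functional as F

A1 = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16)
qa, SA = F.quantize_fp4(A1, blocksize=64)   # packed FP4 codes + quantization state
A2 = F.dequantize_fp4(qa, SA)
assert A2.dtype == A1.dtype
print((A1 - A2).abs().float().mean().item())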
...
@@ -2297,7 +2311,8 @@ def test_4bit_compressed_stats(quant_type):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
@pytest.mark.parametrize("quant_type", ['nf4'])
def test_bench_4bit_dequant(quant_type):
    blocksize = 256
    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
...
@@ -2311,7 +2326,7 @@ def test_bench_4bit_dequant(quant_type):
    #print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024*12, device='cuda').half()
    iters = 5
    iters = 100
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):
...
@@ -2344,139 +2359,88 @@ def test_normal_map_tree():
    print(pivots)

#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
def test_cutlass3_gemm(dtype):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
    debug = True
def test_gemv_4bit(dtype, storage_type, double_quant):
    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
    print('')
    #for dim in [4096, 5120, 6656, 8192]:
    for dim in [128, 256, 512, 1024, 2048, 4096]:
    for dim in [4096]:
    #for dim in [4*1024]:
    #for dim in [128+1]:
    #for dim in [1*16]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0
        for i in range(100):
            A = torch.randn(1, dim, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            #print('')
            #print(A)
            #print(B.t())
            #A[:, :-1] = 0
            #B[:, :-1] = 0
        for i in range(100):
            C1 = torch.matmul(A, B.t())
            C2 = F.cutlass3_gemm(A, B.t())
            # tensor cores are non-deterministic
            # so we need to analyze errors around the mean
            # to test our implementation
            err = torch.abs(C1-C2)
            mag = torch.abs(C1)+1e-8
            relerr = err/mag
            max_err = max(err.max(), max_err)
            max_relerr = max(relerr.max(), max_relerr)
            err = err.mean().item()
            relerr = relerr.mean().item()
            errs.append(err)
            relerrs.append(relerr)
            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
            #    print('')
            #    print(i, err, relerr)
            #    print(A.flatten()[-6:])
            #    print(B.flatten()[-6:])
            #    out = A.flatten()[-6:]*B.flatten()[-6:]
            #    print(out)
            #    print(out[:-1].sum())
            #    print('='*80)
            #    print(C1.flatten()[-6:])
            #    print(C2.flatten()[-6:])
            #    #assert False, 'ERROR'
        c = int(C1.numel()*0.0014*(dim/256))+1
        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
        #print(c/math.sqrt(dim))
        print('')
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        print(dim, (max_err.item(), max_relerr.item()))

#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_gemm_4bit(dtype):
    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
    #for dim in [4096, 5120, 6656, 8192]:
    #for dim in [32]:
    for dim in [4096]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0
        for i in range(1):
            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
            A = torch.randn(1, dim, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
            B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
            #print('')
            #print(A)
            #print(B.t())
            #A[:, :-1] = 0
            #B[:, :-1] = 0
            #A.flatten()[:-1] = 0
            #B.flatten()[:-1] = 0
            qB, state = F.quantize_nf4(B)
            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
            F.dequantize_nf4(qB, state)
            #F.dequantize_4bit(qB, state)
            C3 = torch.matmul(A, B.t())
            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
            C2 = F.gemv_4bit(A, qB.t(), state=state)
            A.requires_grad = True
            C1 = bnb.matmul_4bit(A, qB.t(), state)
            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
            print(C1.shape, C2.shape)
            #print(state)
            #print(qB)
            #print('')
            #print(A)
            #print(B)
            #print('='*89)
            #print(C3)
            #print(C1.shape, C2.shape)
            # tensor cores are non-deterministic
            # so we need to analyze errors around the mean
            # to test our implementation
            err = torch.abs(C1-C2)
            err = torch.abs(C1-C2).float()
            mag = torch.abs(C1)+1e-8
            mag = torch.abs(C1).float()+1e-5
            relerr = err/mag
            max_err = max(err.max(), max_err)
            max_relerr = max(relerr.max(), max_relerr)
            err = err.mean().item()
            relerr = relerr.mean().item()
            #print(err)
            errs.append(err)
            relerrs.append(relerr)
            if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
                print('')
                print(i, err, relerr)
                print(A.flatten()[-6:])
                print(B.flatten()[-6:])
                out = A.flatten()[-6:]*B.flatten()[-6:]
                print(out)
                print(out[:-1].sum())
                print('='*80)
                print(C1.flatten()[-6:])
                print(C2.flatten()[-6:])
                #assert False, 'ERROR'
        c = int(C1.numel()*0.0014*(dim/256))+1
        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        #print(c/math.sqrt(dim))
        #print('')
        print('')
        #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        #print(dim, (max_err.item(), max_relerr.item()))
        print(dim, (max_err.item(), max_relerr.item()))
        print(C1.flatten()[-20:])
        print(C2.flatten()[-20:])
        print(C3.flatten()[-20:])
        print(sum(errs)/len(errs)/math.sqrt(dim), dim)
        print(sum(relerrs)/len(relerrs)/math.sqrt(dim), dim)
        if dtype == torch.float16:
            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
        else:
            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003

@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
...
tests/test_modules.py View file @ 5f492d43
...
@@ -535,6 +535,7 @@ def test_kbit_backprop(module):
    kbit[1].bias.detach().copy_(ref[1].bias)
    ref = ref.half().cuda()
    kbit = kbit.half().cuda()
    kbit = kbit.half().to('cuda')
    errs1 = []
    errs2 = []
...