Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
64cc0592
Commit
64cc0592
authored
Apr 02, 2023
by
Tim Dettmers
Browse files
First draft of NF4.
parent
4ad999d1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
292 additions
and
145 deletions
+292
-145
bitsandbytes/functional.py
bitsandbytes/functional.py
+38
-6
csrc/kernels.cu
csrc/kernels.cu
+184
-91
csrc/kernels.cuh
csrc/kernels.cuh
+2
-2
csrc/ops.cu
csrc/ops.cu
+26
-24
csrc/ops.cuh
csrc/ops.cuh
+9
-2
csrc/pythonInterface.c
csrc/pythonInterface.c
+19
-11
tests/test_functional.py
tests/test_functional.py
+14
-9
No files found.
bitsandbytes/functional.py
View file @
64cc0592
...
...
@@ -688,8 +688,13 @@ def dequantize_blockwise(
return
out
def
quantize_fp4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
return
quantize_4bit_packed
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'fp4'
)
def
quantize_fp4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
)
->
Tensor
:
def
quantize_nf4
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
):
return
quantize_4bit_packed
(
A
,
absmax
,
out
,
blocksize
,
compress_statistics
,
'nf4'
)
def
quantize_4bit_packed
(
A
:
Tensor
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
64
,
compress_statistics
=
False
,
quant_type
=
'fp4'
)
->
Tensor
:
"""
Quantize tensor A in blocks of FP4 values.
...
...
@@ -705,6 +710,8 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
The output tensor (8-bit).
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
-------
...
...
@@ -715,6 +722,8 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
"""
if
A
.
device
.
type
!=
'cuda'
:
raise
NotImplementedError
(
f
'Device type not supported for FP4 quantization:
{
A
.
device
.
type
}
'
)
if
quant_type
not
in
[
'fp4'
,
'nf4'
]:
raise
NotImplementedError
(
f
'4-bit quantization data type
{
quant_type
}
is not implemented.'
)
n
=
A
.
numel
()
input_shape
=
A
.
shape
...
...
@@ -734,9 +743,15 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
is_on_gpu
([
A
,
out
,
absmax
])
if
A
.
dtype
==
torch
.
float32
:
lib
.
cquantize_blockwise_fp32_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
if
quant_type
==
'fp4'
:
lib
.
cquantize_blockwise_fp32_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
else
:
lib
.
cquantize_blockwise_fp32_nf4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
elif
A
.
dtype
==
torch
.
float16
:
lib
.
cquantize_blockwise_fp16_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
if
quant_type
==
'fp4'
:
lib
.
cquantize_blockwise_fp16_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
else
:
lib
.
cquantize_blockwise_fp16_nf4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int32
(
blocksize
),
ct
.
c_int
(
n
))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
post_call
(
A
.
device
)
...
...
@@ -754,8 +769,13 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
return
out
,
state
def
dequantize_fp4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
return
dequantize_4bit_packed
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'fp4'
)
def
dequantize_nf4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
return
dequantize_4bit_packed
(
A
,
quant_state
,
absmax
,
out
,
blocksize
,
'nf4'
)
def
dequantize_
fp4
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
)
->
Tensor
:
def
dequantize_
4bit_packed
(
A
:
Tensor
,
quant_state
:
Tuple
[
Tensor
,
Tensor
]
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
:
int
=
64
,
quant_type
=
'fp4'
)
->
Tensor
:
"""
Dequantizes FP4 blockwise quantized values.
...
...
@@ -771,6 +791,10 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
...
...
@@ -780,6 +804,8 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
"""
if
blocksize
not
in
[
2048
,
4096
,
1024
,
512
,
256
,
128
,
64
]:
raise
ValueError
(
f
"The blockwise of
{
blocksize
}
is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]"
)
if
quant_type
not
in
[
'fp4'
,
'nf4'
]:
raise
NotImplementedError
(
f
'4-bit quantization data type
{
quant_type
}
is not implemented.'
)
if
quant_state
is
None
:
assert
absmax
is
not
None
and
out
is
not
None
...
...
@@ -802,9 +828,15 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
device
=
pre_call
(
A
.
device
)
is_on_gpu
([
A
,
absmax
,
out
])
if
out
.
dtype
==
torch
.
float32
:
lib
.
cdequantize_blockwise_fp32_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
if
quant_type
==
'fp4'
:
lib
.
cdequantize_blockwise_fp32_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
else
:
lib
.
cdequantize_blockwise_fp32_nf4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
elif
out
.
dtype
==
torch
.
float16
:
lib
.
cdequantize_blockwise_fp16_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
if
quant_type
==
'fp4'
:
lib
.
cdequantize_blockwise_fp16_fp4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
else
:
lib
.
cdequantize_blockwise_fp16_nf4
(
get_ptr
(
None
),
get_ptr
(
A
),
get_ptr
(
absmax
),
get_ptr
(
out
),
ct
.
c_int
(
blocksize
),
ct
.
c_int
(
n
))
else
:
raise
ValueError
(
f
"Blockwise quantization only supports 16/32-bit floats, but got
{
A
.
dtype
}
"
)
post_call
(
A
.
device
)
...
...
csrc/kernels.cu
View file @
64cc0592
...
...
@@ -140,44 +140,111 @@ __device__ unsigned char dQuantizeFP4(float x)
return
0b0000
+
sign
;
}
__device__
float
dDequantizeNF4
(
unsigned
char
val
,
float
absmax
)
{
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if
((
val
&
0b1000
)
==
8
)
if
((
val
&
0b0100
)
==
4
)
// 1
if
((
val
&
0b0010
)
==
2
)
// 11
if
((
val
&
0b0001
)
==
1
)
// 111
return
1.0
f
*
absmax
;
else
return
0.7229568362236023
f
*
absmax
;
else
if
((
val
&
0b0001
)
==
1
)
// 110
return
0.5626170039176941
f
*
absmax
;
else
return
0.44070982933044434
f
*
absmax
;
else
if
((
val
&
0b0010
)
==
2
)
//10
if
((
val
&
0b0001
)
==
1
)
// 101
return
0.33791524171829224
f
*
absmax
;
else
return
0.24611230194568634
f
*
absmax
;
else
if
((
val
&
0b0001
)
==
1
)
// 100
return
0.16093020141124725
f
*
absmax
;
else
return
0.07958029955625534
f
*
absmax
;
else
if
((
val
&
0b0100
)
==
4
)
// 0
if
((
val
&
0b0010
)
==
2
)
//01
if
((
val
&
0b0001
)
==
1
)
// 011
return
0.0
f
*
absmax
;
else
return
-
0.09105003625154495
f
*
absmax
;
else
if
((
val
&
0b0001
)
==
1
)
// 010
return
-
0.18477343022823334
f
*
absmax
;
else
return
-
0.28444138169288635
f
*
absmax
;
else
if
((
val
&
0b0010
)
==
2
)
//00
if
((
val
&
0b0001
)
==
1
)
// 001
return
-
0.39491748809814453
f
*
absmax
;
else
return
-
0.5250730514526367
f
*
absmax
;
else
if
((
val
&
0b0001
)
==
1
)
// 000
return
-
0.6961928009986877
f
*
absmax
;
else
return
-
1.0
f
*
absmax
;
}
__device__
unsigned
char
dQuantizeNormal
(
float
x
)
{
// FP4 with bias of 3
// first bit is a sign
// subnormals
// 0b000 = 0
// 0b001 = 0.0625
// 0b110 = 2
// 0b111 = 3
// 0b100 = 4
// 0b101 = 6
// 0b010 = 8
// 0b011 = 12
int
sign
=
x
<
0
?
0b1000
:
0b0000
;
x
=
fabsf
(
x
);
if
(
x
>
3.5
f
)
if
(
x
>
7.0
f
)
if
(
x
>
10.0
f
)
return
0b0011
+
sign
;
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if
(
x
>
0.03979014977812767
f
)
if
(
x
>
0.3893125355243683
f
)
// 1
if
(
x
>
0.6427869200706482
f
)
// 11
if
(
x
>
0.8614784181118011
f
)
// 111
return
0b1111
;
else
return
0b1110
;
else
return
0b0010
+
sign
;
if
(
x
>
0.5016634166240692
f
)
// 110
return
0b1101
;
else
return
0b1100
;
else
if
(
x
>
5.0
f
)
return
0b101
+
sign
;
if
(
x
>
0.2035212516784668
f
)
// 10
if
(
x
>
0.2920137718319893
f
)
// 101
return
0b1011
;
else
return
0b1010
;
else
return
0b100
+
sign
;
if
(
x
>
0.1202552504837513
f
)
// 100
return
0b1001
;
else
return
0b1100
;
else
if
(
x
>
1.03125
f
)
if
(
x
>
2.5
f
)
return
0b0111
+
sign
;
if
(
x
>
-
0.33967943489551544
f
)
// 0
if
(
x
>
-
0.13791173323988914
f
)
// 01
if
(
x
>
-
0.045525018125772476
f
)
// 011
return
0b0111
;
else
return
0b0110
;
else
return
0b0110
+
sign
;
if
(
x
>
-
0.23460740596055984
f
)
// 010
return
0b0101
;
else
return
0b0100
;
else
if
(
x
>
0.03125
f
)
return
0b0001
+
sign
;
if
(
x
>
-
0.6106329262256622
f
)
// 00
if
(
x
>
-
0.4599952697753906
f
)
// 001
return
0b0011
;
else
return
0b0010
;
else
return
0b0000
+
sign
;
if
(
x
>
-
0.8480964004993439
f
)
// 000
return
0b0001
;
else
return
0b0000
;
}
template
<
int
STOCHASTIC
>
...
...
@@ -564,7 +631,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
}
}
template
<
typename
T
,
int
BLOCK_SIZE
,
int
NUM_PER_TH
,
int
STOCHASTIC
,
int
FP4
>
template
<
typename
T
,
int
BLOCK_SIZE
,
int
NUM_PER_TH
,
int
STOCHASTIC
,
int
DATA_TYPE
>
//__launch_bounds__(TH, 4)
__global__
void
kQuantizeBlockwise
(
float
*
code
,
T
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
)
{
...
...
@@ -574,13 +641,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
T
vals
[
NUM_PER_TH
];
float
rand_vals
[
NUM_PER_TH
];
unsigned
char
qvals
[
FP4
?
NUM_PER_TH
/
2
:
NUM_PER_TH
];
unsigned
char
qvals
[
(
DATA_TYPE
>
0
)
?
NUM_PER_TH
/
2
:
NUM_PER_TH
];
//float local_abs_max = -FLT_MAX;
float
local_abs_max
=
0.0
f
;
int
local_rand_idx
=
0
;
typedef
cub
::
BlockLoad
<
T
,
BLOCK_SIZE
/
NUM_PER_TH
,
NUM_PER_TH
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadT
;
typedef
cub
::
BlockStore
<
unsigned
char
,
BLOCK_SIZE
/
NUM_PER_TH
,
FP4
?
NUM_PER_TH
/
2
:
NUM_PER_TH
,
cub
::
BLOCK_STORE_WARP_TRANSPOSE
>
StoreChar
;
typedef
cub
::
BlockStore
<
unsigned
char
,
BLOCK_SIZE
/
NUM_PER_TH
,
(
DATA_TYPE
>
0
)
?
NUM_PER_TH
/
2
:
NUM_PER_TH
,
cub
::
BLOCK_STORE_WARP_TRANSPOSE
>
StoreChar
;
typedef
cub
::
BlockReduce
<
float
,
BLOCK_SIZE
/
NUM_PER_TH
>
BlockReduce
;
typedef
cub
::
BlockLoad
<
float
,
BLOCK_SIZE
/
NUM_PER_TH
,
NUM_PER_TH
,
cub
::
BLOCK_LOAD_WARP_TRANSPOSE
>
LoadFloat
;
...
...
@@ -591,7 +658,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
__shared__
float
smem_code
[
256
];
__shared__
float
smem_absmax_value
[
1
];
if
(
!
FP4
)
if
(
DATA_TYPE
==
General8bit
)
for
(
int
i
=
threadIdx
.
x
;
i
<
256
;
i
+=
blockDim
.
x
)
smem_code
[
i
]
=
code
[
i
];
...
...
@@ -633,31 +700,41 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
LoadFloat
(
loadf
).
Load
(
&
rand
[
local_rand_idx
],
rand_vals
,
BLOCK_SIZE
,
0
);
}
if
(
FP4
)
unsigned
char
packed_4bit
=
0
;
switch
(
DATA_TYPE
)
{
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
/
2
;
j
++
)
{
unsigned
char
packed_fp4
=
0
;
packed_fp4
|=
dQuantizeFP4
(((
float
)
vals
[
2
*
j
])
*
local_abs_max
)
<<
4
;
packed_fp4
|=
dQuantizeFP4
(((
float
)
vals
[
2
*
j
+
1
])
*
local_abs_max
);
qvals
[
j
]
=
packed_fp4
;
}
}
else
{
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
{
if
(
!
STOCHASTIC
)
qvals
[
j
]
=
dQuantize
<
0
>
(
smem_code
,
0.0
f
,
((
float
)
vals
[
j
])
*
local_abs_max
);
else
qvals
[
j
]
=
dQuantize
<
1
>
(
smem_code
,
rand_vals
[
j
],
((
float
)
vals
[
j
])
*
local_abs_max
);
}
case
General8bit
:
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
;
j
++
)
{
if
(
!
STOCHASTIC
)
qvals
[
j
]
=
dQuantize
<
0
>
(
smem_code
,
0.0
f
,
((
float
)
vals
[
j
])
*
local_abs_max
);
else
qvals
[
j
]
=
dQuantize
<
1
>
(
smem_code
,
rand_vals
[
j
],
((
float
)
vals
[
j
])
*
local_abs_max
);
}
break
;
case
FP4
:
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
/
2
;
j
++
)
{
packed_4bit
|=
dQuantizeFP4
(((
float
)
vals
[
2
*
j
])
*
local_abs_max
)
<<
4
;
packed_4bit
|=
dQuantizeFP4
(((
float
)
vals
[
2
*
j
+
1
])
*
local_abs_max
);
qvals
[
j
]
=
packed_4bit
;
}
break
;
case
NF4
:
#pragma unroll NUM_PER_TH
for
(
int
j
=
0
;
j
<
NUM_PER_TH
/
2
;
j
++
)
{
packed_4bit
|=
dQuantizeNormal
(((
float
)
vals
[
2
*
j
])
*
local_abs_max
)
<<
4
;
packed_4bit
|=
dQuantizeNormal
(((
float
)
vals
[
2
*
j
+
1
])
*
local_abs_max
);
qvals
[
j
]
=
packed_4bit
;
}
break
;
}
__syncthreads
();
StoreChar
(
storec
).
Store
(
&
(
out
[
FP4
?
i
/
2
:
i
]),
qvals
,
FP4
?
(
valid_items
+
1
)
/
2
:
valid_items
);
StoreChar
(
storec
).
Store
(
&
(
out
[
(
DATA_TYPE
>
0
)
?
i
/
2
:
i
]),
qvals
,
(
DATA_TYPE
>
0
)
?
(
valid_items
+
1
)
/
2
:
valid_items
);
}
}
...
...
@@ -2957,44 +3034,60 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
template
__global__
void
kPercentileClipping
<
float
,
2048
,
4
>(
float
*
__restrict__
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
);
template
__global__
void
kPercentileClipping
<
half
,
2048
,
4
>(
half
*
__restrict__
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
4096
,
4
,
0
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
4096
,
4
,
0
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
4096
,
4
,
1
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
4096
,
4
,
1
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
2048
,
4
,
0
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
2048
,
4
,
0
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
1024
,
4
,
0
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
1024
,
4
,
0
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
512
,
2
,
0
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
512
,
2
,
0
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
256
,
2
,
0
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
256
,
2
,
0
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
128
,
2
,
0
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
128
,
2
,
0
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
64
,
2
,
0
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
64
,
2
,
0
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
32
,
1
,
0
,
0
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
32
,
1
,
0
,
0
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
4096
,
4
,
0
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
4096
,
4
,
0
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
2048
,
4
,
0
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
2048
,
4
,
0
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
1024
,
4
,
0
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
1024
,
4
,
0
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
512
,
2
,
0
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
512
,
2
,
0
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
256
,
2
,
0
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
256
,
2
,
0
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
128
,
2
,
0
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
128
,
2
,
0
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
half
,
64
,
2
,
0
,
1
>(
float
*
code
,
half
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kQuantizeBlockwise
<
float
,
64
,
2
,
0
,
1
>(
float
*
code
,
float
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
half
,
512
,
64
,
8
,
1
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
1
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
half
,
512
,
64
,
8
,
0
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
0
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
MAKE_kQuantizeBlockwise
(
half
,
4096
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
4096
,
4
,
1
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
2048
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
1024
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
512
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
256
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
128
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
64
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
4096
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
4096
,
4
,
1
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
2048
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
1024
,
4
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
512
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
256
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
128
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
float
,
64
,
2
,
0
,
General8bit
)
MAKE_kQuantizeBlockwise
(
half
,
4096
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
2048
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
1024
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
512
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
256
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
128
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
64
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
4096
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
2048
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
1024
,
4
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
512
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
256
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
128
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
float
,
64
,
2
,
0
,
FP4
)
MAKE_kQuantizeBlockwise
(
half
,
4096
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
2048
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
1024
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
512
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
256
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
128
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
half
,
64
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
4096
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
2048
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
1024
,
4
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
512
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
256
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
128
,
2
,
0
,
NF4
)
MAKE_kQuantizeBlockwise
(
float
,
64
,
2
,
0
,
NF4
)
template
__global__
void
kDequantizeBlockwise
<
half
,
512
,
64
,
8
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
half
,
512
,
64
,
8
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
half
,
512
,
64
,
8
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
const
int
blocksize
,
const
int
n
);
template
__global__
void
kDequantizeBlockwise
<
float
,
512
,
64
,
8
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
const
int
blocksize
,
const
int
n
);
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
...
...
csrc/kernels.cuh
View file @
64cc0592
...
...
@@ -14,8 +14,8 @@ template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A,
__global__
void
kQuantize
(
float
*
code
,
float
*
__restrict__
const
A
,
unsigned
char
*
out
,
const
int
n
);
__global__
void
kDequantize
(
float
*
code
,
unsigned
char
*
A
,
float
*
out
,
const
int
n
);
template
<
typename
T
,
int
BLOCK_SIZE
,
int
NUM_PER_TH
,
int
STOCHASTIC
,
int
FP4
>
__global__
void
kQuantizeBlockwise
(
float
*
code
,
T
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
<
typename
T
,
int
BLOCK_SIZE
,
int
THREADS
,
int
NUM_PER_TH
,
int
FP4
>
__global__
void
kDequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
const
int
blocksize
,
const
int
n
);
template
<
typename
T
,
int
BLOCK_SIZE
,
int
NUM_PER_TH
,
int
STOCHASTIC
,
int
DATA_TYPE
>
__global__
void
kQuantizeBlockwise
(
float
*
code
,
T
*
__restrict__
const
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
__restrict__
const
rand
,
const
int
rand_offset
,
const
int
n
);
template
<
typename
T
,
int
BLOCK_SIZE
,
int
THREADS
,
int
NUM_PER_TH
,
int
DATA_TYPE
>
__global__
void
kDequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
const
int
blocksize
,
const
int
n
);
template
<
typename
T
,
int
OPTIMIZER
,
int
BLOCK_SIZE
,
int
NUM_VALS
>
__global__
void
kPreconditionOptimizer32bit2State
(
T
*
g
,
T
*
p
,
...
...
csrc/ops.cu
View file @
64cc0592
...
...
@@ -50,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
template
<
typename
T
,
int
STOCHASTIC
,
int
FP4
>
void
quantizeBlockwise
(
float
*
code
,
T
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
)
template
<
typename
T
,
int
STOCHASTIC
,
int
DATA_TYPE
>
void
quantizeBlockwise
(
float
*
code
,
T
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
)
{
int
num_blocks
=
n
/
blocksize
;
num_blocks
=
n
%
blocksize
==
0
?
num_blocks
:
num_blocks
+
1
;
...
...
@@ -60,34 +60,32 @@ template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * co
if
(
blocksize
==
4096
)
kQuantizeBlockwise
<
T
,
4096
,
4
,
STOCHASTIC
,
0
><<<
num_blocks
,
1024
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
2048
)
kQuantizeBlockwise
<
T
,
2048
,
4
,
0
,
FP4
><<<
num_blocks
,
512
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
2048
,
4
,
0
,
DATA_TYPE
><<<
num_blocks
,
512
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
1024
)
kQuantizeBlockwise
<
T
,
1024
,
4
,
0
,
FP4
><<<
num_blocks
,
256
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
1024
,
4
,
0
,
DATA_TYPE
><<<
num_blocks
,
256
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
512
)
kQuantizeBlockwise
<
T
,
512
,
2
,
0
,
FP4
><<<
num_blocks
,
256
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
512
,
2
,
0
,
DATA_TYPE
><<<
num_blocks
,
256
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
256
)
kQuantizeBlockwise
<
T
,
256
,
2
,
0
,
FP4
><<<
num_blocks
,
128
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
256
,
2
,
0
,
DATA_TYPE
><<<
num_blocks
,
128
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
128
)
kQuantizeBlockwise
<
T
,
128
,
2
,
0
,
FP4
><<<
num_blocks
,
64
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
128
,
2
,
0
,
DATA_TYPE
><<<
num_blocks
,
64
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
64
)
kQuantizeBlockwise
<
T
,
64
,
2
,
0
,
FP4
><<<
num_blocks
,
32
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
else
if
(
blocksize
==
32
and
FP4
==
0
)
kQuantizeBlockwise
<
T
,
32
,
1
,
0
,
FP4
><<<
num_blocks
,
32
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
kQuantizeBlockwise
<
T
,
64
,
2
,
0
,
DATA_TYPE
><<<
num_blocks
,
32
>>>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
template
<
typename
T
,
int
FP4
>
void
dequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
int
blocksize
,
const
int
n
)
template
<
typename
T
,
int
DATA_TYPE
>
void
dequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
int
blocksize
,
const
int
n
)
{
int
num_blocks
=
n
/
blocksize
;
num_blocks
=
n
%
blocksize
==
0
?
num_blocks
:
num_blocks
+
1
;
int
tile_size
=
FP4
?
1024
:
512
;
int
tile_size
=
(
DATA_TYPE
>
0
)
?
1024
:
512
;
if
(
FP4
)
kDequantizeBlockwise
<
T
,
512
,
64
,
8
,
FP4
><<<
(
n
+
tile_size
-
1
)
/
tile_size
,
64
>>>
(
code
,
A
,
absmax
,
out
,
blocksize
/
2
,
n
);
if
(
DATA_TYPE
>
0
)
kDequantizeBlockwise
<
T
,
512
,
64
,
8
,
DATA_TYPE
><<<
(
n
+
tile_size
-
1
)
/
tile_size
,
64
>>>
(
code
,
A
,
absmax
,
out
,
blocksize
/
2
,
n
);
else
kDequantizeBlockwise
<
T
,
512
,
64
,
8
,
FP4
><<<
(
n
+
tile_size
-
1
)
/
tile_size
,
64
>>>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
kDequantizeBlockwise
<
T
,
512
,
64
,
8
,
DATA_TYPE
><<<
(
n
+
tile_size
-
1
)
/
tile_size
,
64
>>>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
...
...
@@ -682,16 +680,20 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
template
void
estimateQuantiles
(
half
*
A
,
float
*
code
,
float
offset
,
int
n
);
template
void
estimateQuantiles
(
float
*
A
,
float
*
code
,
float
offset
,
int
n
);
template
void
quantizeBlockwise
<
half
,
0
,
0
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
0
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
0
,
1
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
1
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
1
,
0
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
1
,
0
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
0
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
float
,
0
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
1
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
float
,
1
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
1
,
General8bit
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
1
,
General8bit
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
0
,
General8bit
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
General8bit
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
0
,
FP4
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
FP4
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
half
,
0
,
NF4
>(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
quantizeBlockwise
<
float
,
0
,
NF4
>(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
float
,
General8bit
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
float
,
FP4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
half
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
);
template
void
dequantizeBlockwise
<
float
,
NF4
>(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
);
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
...
...
csrc/ops.cuh
View file @
64cc0592
...
...
@@ -81,6 +81,13 @@ typedef enum Transform_t
COL_AMPERE
=
4
,
}
Transform_t
;
typedef
enum
DataType_t
{
General8bit
=
0
,
FP4
=
1
,
NF4
=
2
,
}
DataType_t
;
class
Context
{
public:
...
...
@@ -128,8 +135,8 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
void
quantize
(
float
*
code
,
float
*
A
,
unsigned
char
*
out
,
int
n
);
void
dequantize
(
float
*
code
,
unsigned
char
*
A
,
float
*
out
,
int
n
);
template
<
typename
T
,
int
STOCHASTIC
,
int
FP4
>
void
quantizeBlockwise
(
float
*
code
,
T
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
<
typename
T
,
int
FP4
>
void
dequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
int
block_size
,
const
int
n
);
template
<
typename
T
,
int
STOCHASTIC
,
int
DATA_TYPE
>
void
quantizeBlockwise
(
float
*
code
,
T
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
int
blocksize
,
const
int
n
);
template
<
typename
T
,
int
DATA_TYPE
>
void
dequantizeBlockwise
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
T
*
out
,
int
block_size
,
const
int
n
);
template
<
typename
T
,
int
OPTIMIZER
>
void
optimizer32bit
(
T
*
g
,
T
*
p
,
float
*
state1
,
float
*
state2
,
float
*
unorm
,
float
max_unorm
,
float
param_norm
,
...
...
csrc/pythonInterface.c
View file @
64cc0592
...
...
@@ -76,17 +76,21 @@ MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
void
percentileClipping_g32
(
float
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
){
percentileClipping
<
float
>
(
g
,
gnorm_vec
,
step
,
n
);
}
void
percentileClipping_g16
(
half
*
g
,
float
*
gnorm_vec
,
int
step
,
const
int
n
){
percentileClipping
<
half
>
(
g
,
gnorm_vec
,
step
,
n
);
}
void
quantizeBlockwise_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
,
0
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
0
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_stochastic_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
const
int
n
){
quantizeBlockwise
<
half
,
1
,
0
>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
4096
,
n
);
}
void
quantizeBlockwise_stochastic_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
const
int
n
){
quantizeBlockwise
<
float
,
1
,
0
>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
4096
,
n
);
}
void
quantizeBlockwise_fp16_fp4
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
,
1
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32_fp4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
1
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
,
0
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
0
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
,
1
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
1
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_stochastic_fp16
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
const
int
n
){
quantizeBlockwise
<
half
,
1
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
4096
,
n
);
}
void
quantizeBlockwise_stochastic_fp32
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
float
*
rand
,
int
rand_offset
,
const
int
n
){
quantizeBlockwise
<
float
,
1
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
rand
,
rand_offset
,
4096
,
n
);
}
void
quantizeBlockwise_fp16_fp4
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32_fp4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp16_nf4
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
half
,
0
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
quantizeBlockwise_fp32_nf4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise
<
float
,
0
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
NULL
,
0
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
General8bit
>
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
FP4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
dequantizeBlockwise_fp16_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
half
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
\
void
dequantizeBlockwise_fp32_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise
<
float
,
NF4
>
(
NULL
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
...
...
@@ -157,6 +161,10 @@ extern "C"
void
cquantize_blockwise_fp32_fp4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp32_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp16_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp16_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp32_fp4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp32_fp4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp16_nf4
(
float
*
code
,
half
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp16_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cquantize_blockwise_fp32_nf4
(
float
*
code
,
float
*
A
,
float
*
absmax
,
unsigned
char
*
out
,
int
blocksize
,
const
int
n
){
quantizeBlockwise_fp32_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp16_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
half
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp16_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
void
cdequantize_blockwise_fp32_nf4
(
float
*
code
,
unsigned
char
*
A
,
float
*
absmax
,
float
*
out
,
int
blocksize
,
const
int
n
){
dequantizeBlockwise_fp32_nf4
(
code
,
A
,
absmax
,
out
,
blocksize
,
n
);
}
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
...
...
tests/test_functional.py
View file @
64cc0592
...
...
@@ -2254,16 +2254,18 @@ def test_fp4_quant():
assert
relerr
.
item
()
<
0.28
def
test_fp4_compressed_stats
():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'fp4'
,
'nf4'
])
def
test_4bit_compressed_stats
(
quant_type
):
for
blocksize
in
[
128
,
64
]:
errs1
=
[]
errs2
=
[]
for
i
in
range
(
10
000
):
for
i
in
range
(
10
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
).
half
()
q2
,
SA2
=
F
.
quantize_
fp4
(
A1
,
blocksize
=
blocksize
)
q3
,
SA3
=
F
.
quantize_
fp4
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
)
A2
=
F
.
dequantize_
fp4
(
q2
,
SA2
)
A3
=
F
.
dequantize_
fp4
(
q3
,
SA3
)
q2
,
SA2
=
F
.
quantize_
4bit_packed
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
q3
,
SA3
=
F
.
quantize_
4bit_packed
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_
4bit_packed
(
q2
,
SA2
,
quant_type
=
quant_type
)
A3
=
F
.
dequantize_
4bit_packed
(
q3
,
SA3
,
quant_type
=
quant_type
)
err
=
(
A1
-
A2
).
abs
().
float
()
...
...
@@ -2290,10 +2292,12 @@ def test_fp4_compressed_stats():
def
test_bench_fp4_dequant
():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'fp4'
,
'nf4'
])
def
test_bench_fp4_dequant
(
quant_type
):
blocksize
=
256
a
=
torch
.
rand
(
1024
*
12
*
4
,
1024
*
12
,
device
=
'cuda'
).
half
()
qa
,
SA
=
F
.
quantize_
fp4
(
a
,
blocksize
=
blocksize
)
qa
,
SA
=
F
.
quantize_
4bit_packed
(
a
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
input_size
=
a
.
numel
()
/
2
output_size
=
a
.
numel
()
*
2
...
...
@@ -2307,7 +2311,7 @@ def test_bench_fp4_dequant():
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
F
.
dequantize_
fp4
(
qa
,
SA
,
blocksize
=
blocksize
)
F
.
dequantize_
4bit_packed
(
qa
,
SA
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
#b.copy_(a)
torch
.
cuda
.
synchronize
()
#print((time.time()-t0)/iters*1e6)
...
...
@@ -2325,6 +2329,7 @@ def test_normal_map_tree():
code
=
F
.
create_normal_map
()
values
=
code
[:
8
].
tolist
()
+
code
[
-
8
:].
tolist
()
num_pivots
=
1
print
(
values
)
while
num_pivots
<
16
:
idx
=
list
(
range
(
16
//
num_pivots
//
2
,
16
,
16
//
num_pivots
))
print
(
idx
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment