OpenDAS / bitsandbytes · Commit 2dd5d690
Authored Apr 02, 2023 by Tim Dettmers
Parent commit: 51a21df7

Generalized FP4 data type.
Changes: 2 changed files with 88 additions and 66 deletions

  csrc/kernels.cu            +82  −62
  tests/test_functional.py    +6   −4
csrc/kernels.cu

@@ -64,6 +64,33 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
   }
 }

+__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
+{
+  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
+  if((val & 0b0100) == 4) // 0
+    if((val & 0b0010) == 2) //01
+      if((val & 0b0001) == 1) // 111
+        return 0.25000000f*absmax*sign; // 1111
+      else
+        return 0.16666667f*absmax*sign; // 1110
+    else
+      if((val & 0b0001) == 1) // 110
+        return 0.50000000f*absmax*sign; // 1101
+      else
+        return 0.33333333f*absmax*sign; // 1100
+  else
+    if((val & 0b0010) == 2) //10
+      if((val & 0b0001) == 1) // 101
+        return 1.00000000f*absmax*sign; // 1011
+      else
+        return 0.66666667f*absmax*sign; // 1010
+    else
+      if((val & 0b0001) == 1) // 100
+        return 5.208333333e-03f*absmax*sign; // 1001
+      else
+        return 0.00000000f*absmax*sign; // 1000
+}
+
 __device__ unsigned char dQuantizeFP4(float x)
 {
   // FP4 with bias of 3
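The tree above decodes a 4-bit code: bit 3 is the sign, and the remaining three bits select one of eight magnitudes. The constants are the FP4 values {0, 0.0625, 2, 3, 4, 6, 8, 12} divided by 12 (the FP4 absmax), so a code dequantizes directly against a [-1, 1]-normalized scale. For illustration, a minimal host-side sketch of the same codebook in Python (dequantize_fp4_ref is a hypothetical name, not part of the commit):

    # Reference codebook for dDequantizeFP4Tree: the low 3 bits index a
    # magnitude, bit 3 is the sign. Magnitudes are FP4 values divided by 12.
    FP4_MAGNITUDES = {
        0b000: 0.0 / 12,     # 0
        0b001: 0.0625 / 12,  # 5.208333e-03 (subnormal)
        0b010: 8.0 / 12,     # 0.66666667
        0b011: 12.0 / 12,    # 1.0
        0b100: 4.0 / 12,     # 0.33333333
        0b101: 6.0 / 12,     # 0.5
        0b110: 2.0 / 12,     # 0.16666667
        0b111: 3.0 / 12,     # 0.25
    }

    def dequantize_fp4_ref(val: int, absmax: float) -> float:
        sign = -1.0 if (val & 0b1000) else 1.0
        return sign * FP4_MAGNITUDES[val & 0b0111] * absmax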
@@ -78,42 +105,79 @@ __device__ unsigned char dQuantizeFP4(float x)
   // 0b010 = 8
   // 0b011 = 12
+
+  // we do a binary search
+  // the pivots are divided by 12 (the FP4 absmax)
+  // since we assume the input data is in [-1.0, 1.0]
+
+  // !be careful here, it's easy to make a mistake
+  // that is difficult to notice if you add an extra
+  // zero somewhere!
+
+  int sign = x < 0 ? 0b1000 : 0b0000;
+  x = fabsf(x);
+  if(x > 0.29166667f)
+    if( x > 0.583333f)
+      if( x > 0.8333333f)
+        return 0b0011+sign;
+      else
+        return 0b0010+sign;
+    else
+      if(x > 0.4166667f)
+        return 0b101+sign;
+      else
+        return 0b100+sign;
+  else
+    if(x > 0.0859375f)
+      if(x > 0.20833333f)
+        return 0b0111+sign;
+      else
+        return 0b0110+sign;
+    else
+      if(x > 0.00260417f)
+        return 0b0001+sign;
+      else
+        return 0b0000+sign;
+}
+
+__device__ unsigned char dQuantizeNormal(float x)
+{
+  // FP4 with bias of 3
+  // first bit is a sign
+  // subnormals
+  // 0b000 = 0
+  // 0b001 = 0.0625
+  // 0b110 = 2
+  // 0b111 = 3
+  // 0b100 = 4
+  // 0b101 = 6
+  // 0b010 = 8
+  // 0b011 = 12
   int sign = x < 0 ? 0b1000 : 0b0000;
   x = fabsf(x);
   if(x > 3.5f)
-  {
     if( x > 7.0f)
-    {
       if( x > 10.0f)
         return 0b0011+sign;
       else
         return 0b0010+sign;
-    }
     else
-    {
       if(x > 5.0f)
         return 0b101+sign;
       else
         return 0b100+sign;
-    }
-  }
   else
-  {
     if(x > 1.03125f)
-    {
       if( x > 2.5f)
         return 0b0111+sign;
       else
         return 0b0110+sign;
-    }
     else
-    {
       if(x > 0.03125f)
         return 0b0001+sign;
       else
         return 0b0000+sign;
-    }
-  }
 }

 template<int STOCHASTIC>
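The seven pivots in the new dQuantizeFP4 are not arbitrary: each is the midpoint between two adjacent FP4 values, divided by 12. They are exactly the old thresholds (3.5, 7, 10, 5, 1.03125, 2.5, 0.03125) rescaled into [0, 1]. A quick sanity check, again as an illustrative Python sketch:

    # The binary-search pivots are midpoints of adjacent FP4 values, over 12.
    fp4_values = [0.0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0]
    pivots = [(lo + hi) / 2 / 12 for lo, hi in zip(fp4_values, fp4_values[1:])]
    # pivots == [0.00260417, 0.0859375, 0.20833333, 0.29166667,
    #            0.41666667, 0.58333333, 0.83333333]
    # i.e. the seven comparison constants in the kernel, in sorted order.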
@@ -575,8 +639,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
   for(int j = 0; j < NUM_PER_TH/2; j++)
   {
     unsigned char packed_fp4 = 0;
-    packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4;
-    packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f);
+    packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
+    packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
     qvals[j] = packed_fp4;
   }
 }
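With dQuantizeFP4 now expecting inputs already normalized into [-1, 1], the *12.0f rescale disappears from the packing loop; scaling by local_abs_max alone supplies the per-block normalization. Each output byte then carries two codes, with the even element in the high nibble. An illustrative Python sketch of the quantize-and-pack step (quantize_fp4_ref and pack_fp4_pair are hypothetical helpers mirroring the kernel's binary search, not part of the commit):

    import bisect

    # Pivots as in dQuantizeFP4 (midpoints / 12); codes ordered by magnitude.
    PIVOTS = [0.00260417, 0.0859375, 0.20833333, 0.29166667,
              0.41666667, 0.58333333, 0.83333333]
    CODES = [0b000, 0b001, 0b110, 0b111, 0b100, 0b101, 0b010, 0b011]

    def quantize_fp4_ref(x: float) -> int:
        # x is assumed pre-normalized into [-1.0, 1.0].
        sign = 0b1000 if x < 0 else 0b0000
        return CODES[bisect.bisect_left(PIVOTS, abs(x))] + sign

    def pack_fp4_pair(x0: float, x1: float) -> int:
        # Two 4-bit codes per byte: even element in the high nibble,
        # as in the kQuantizeBlockwise loop above.
        return (quantize_fp4_ref(x0) << 4) | quantize_fp4_ref(x1)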
@@ -639,8 +703,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
   #pragma unroll NUM_PER_TH
   for(int j = 0; j < NUM_PER_TH; j++)
   {
-    vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
-    vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
+    //vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
+    //vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
+    vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
+    vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
   }
 }
 else
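On the dequantization side, the high nibble (qvals[j] >> 4) decodes first, then the low nibble (qvals[j] & 0x0F); and since dDequantizeFP4Tree bakes the 1/12 scale into its constants, the old local_abs_max*0.083333f correction (that is, absmax/12) is gone. Reusing the hypothetical helpers sketched above, a round trip lands back on the codebook (illustrative only):

    def unpack_fp4_byte(byte_val: int, absmax: float):
        # Mirrors kDequantizeBlockwise: vals[2*j] from the high nibble,
        # vals[2*j + 1] from the low nibble.
        return (dequantize_fp4_ref(byte_val >> 4, absmax),
                dequantize_fp4_ref(byte_val & 0x0F, absmax))

    packed = pack_fp4_pair(0.3, -0.9)
    print(unpack_fp4_byte(packed, absmax=1.0))  # -> (0.33333..., -1.0)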
@@ -656,52 +722,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
   }
 }

-//template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int TILE_SIZE>
-//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store)
-//{
-//
-//  const int n_load = n_store/2;
-//  const int base_idx = (blockIdx.x * TILE_SIZE);
-//
-//  T vals[NUM_PER_TH*2];
-//  unsigned char qvals[NUM_PER_TH];
-//
-//  int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE;
-//  int idx = base_idx + (threadIdx.x*NUM_PER_TH);
-//
-//  float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]);
-//
-//  if(valid_items == TILE_SIZE)
-//  {
-//    // we do 64 byte loads so we can 128 byte stores
-//    reinterpret_cast<int2(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int2*>(A)[idx/8];
-//  }
-//  else
-//  {
-//    #pragma unroll
-//    for(int j = 0; j < NUM_PER_TH; j++)
-//      if(idx+j < n_load)
-//        qvals[j] = A[idx+j];
-//      else
-//        qvals[j] = 0;
-//  }
-//
-//
-//  #pragma unroll NUM_PER_TH
-//  for(int j = 0; j < NUM_PER_TH; j++)
-//  {
-//    vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f);
-//    vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f);
-//  }
-//
-//
-//  reinterpret_cast<int4(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int4*>(A)[idx/8];
-//  reinterpret_cast<int4*>(A)[idx/16] = reinterpret_cast<int4(&)[16]>(local_valC)[j/num_items];
-//
-//
-//}

 __global__ void kDequantize(float *code, unsigned char * A, float *out, const int n)
 {
   const unsigned int numThreads = blockDim.x * gridDim.x;
tests/test_functional.py

@@ -2246,8 +2246,10 @@ def test_fp4_quant():
     err = (A1 - A2).abs().float()
     relerr = (err/A1.abs().float()).mean()
+    idx = err > 1.0
     err = err.mean()

     assert err.item() < 0.1
     assert relerr.item() < 0.28
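For reference, err above is the mean absolute round-trip error and relerr the mean relative error of a quantize-then-dequantize pass. A minimal standalone version of the same check, sketched against the bitsandbytes functional API as it is used elsewhere in this file (an assumption, not part of the commit):

    import torch
    import bitsandbytes.functional as F

    # Hedged sketch of an FP4 round-trip error check in the style of
    # test_fp4_quant; assumes F.quantize_fp4/F.dequantize_fp4 behave as
    # in test_fp4_compressed_stats below.
    A1 = torch.randn(1024, 1024, device='cuda').half()
    qa, SA = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(qa, SA)

    err = (A1 - A2).abs().float()
    relerr = (err / (A1.abs().float() + 1e-15)).mean()  # 1e-15 guards /0
    assert err.mean().item() < 0.1
    assert relerr.item() < 0.28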
@@ -2256,7 +2258,7 @@ def test_fp4_compressed_stats():
     for blocksize in [128, 64]:
         errs1 = []
         errs2 = []
-        for i in range(10):
+        for i in range(10000):
             A1 = torch.randn(1024, 1024, device='cuda').half()
             q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
             q3, SA3 = F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
@@ -2268,7 +2270,7 @@ def test_fp4_compressed_stats():
             relerr = (err/(A1.abs().float()+1e-15)).mean()
             err = err.mean()
-            errs1.append(err.item())
+            errs1.append(relerr.item())

             assert err.item() < 0.11
             assert relerr.item() < 0.28
@@ -2277,7 +2279,7 @@ def test_fp4_compressed_stats():
             relerr = (err/(A1.abs().float()+1e-15)).mean()
             err = err.mean()
-            errs2.append(err.item())
+            errs2.append(relerr.item())

             assert err.item() < 0.11
             assert relerr.item() < 0.28
@@ -2301,7 +2303,7 @@ def test_bench_fp4_dequant():
     #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()

-    iters = 5
+    iters = 500
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):